YOLO for African Wildlife Object Detection¶
In [1]:
import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
random.seed(33)
# Define dataset paths
dataset_path = "data/wildlife/valid"
image_folder = os.path.join(dataset_path, "images")
label_folder = os.path.join(dataset_path, "labels")
# Class labels dictionary
class_labels = {0: "Buffalo", 1: "Elephant", 2: "Rhino", 3: "Zebra"}
colors = {0: (255, 0, 0), 1: (0, 255, 0), 2: (0, 0, 255), 3: (255, 255, 0)} # Colors for each class
# Get list of image files
image_files = [f for f in os.listdir(image_folder) if f.endswith(('.jpg', '.png'))]
# Select 6 random images
selected_images = random.sample(image_files, min(9, len(image_files)))
def draw_bboxes(image_path, label_path):
"""Draw bounding boxes on an image using YOLO annotations."""
# Load image
image = cv2.imread(image_path)
height, width, _ = image.shape
# Read YOLO label file
with open(label_path, "r") as f:
lines = f.readlines()
# Draw bounding boxes
for line in lines:
class_id, x_center, y_center, bbox_width, bbox_height = map(float, line.strip().split())
# Convert YOLO format (normalized) to pixel values
x_center, y_center = int(x_center * width), int(y_center * height)
bbox_width, bbox_height = int(bbox_width * width), int(bbox_height * height)
x_min, y_min = x_center - bbox_width // 2, y_center - bbox_height // 2
x_max, y_max = x_center + bbox_width // 2, y_center + bbox_height // 2
# Draw rectangle
color = colors[int(class_id)]
cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 2)
# Add class label
label_text = f"{class_labels[int(class_id)]}"
cv2.putText(image, label_text, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert BGR to RGB for matplotlib
# Plot 6 images with bounding boxes
fig, axes = plt.subplots(3, 3, figsize=(15, 15))
for ax, image_file in zip(axes.flatten(), selected_images):
image_path = os.path.join(image_folder, image_file)
label_path = os.path.join(label_folder, image_file.replace(".jpg", ".txt").replace(".png", ".txt"))
if os.path.exists(label_path):
processed_image = draw_bboxes(image_path, label_path)
ax.imshow(processed_image)
ax.set_title(image_file)
ax.axis("off")
else:
ax.axis("off")
plt.tight_layout()
plt.show()
In [2]:
import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO
# Load YOLOv11 pre-trained model
model = YOLO("yolo11s.pt")
# Function to draw predicted bounding boxes
def draw_predictions(image_path):
"""Runs YOLOv11 inference and draws bounding boxes on the image."""
# Load image
image = cv2.imread(image_path)
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Run YOLOv11 model on the image
results = model(image_rgb)
# Draw bounding boxes
for result in results:
boxes = result.boxes.xyxy # Bounding boxes (x1, y1, x2, y2)
scores = result.boxes.conf # Confidence scores
labels = result.boxes.cls # Class labels
for i, box in enumerate(boxes):
x1, y1, x2, y2 = map(int, box) # Convert to integers
label = model.names[int(labels[i])]
score = scores[i]
# Draw bounding box
cv2.rectangle(image_rgb, (x1, y1), (x2, y2), (255, 0, 0), 2)
cv2.putText(image_rgb, f"{label} {score:.2f}", (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
return image_rgb # Return processed image
# Plot 6 images with YOLOv11 predictions
fig, axes = plt.subplots(3, 3, figsize=(15, 15))
for ax, image_file in zip(axes.flatten(), selected_images):
image_path = os.path.join(image_folder, image_file)
# Process image with YOLOv11
predicted_image = draw_predictions(image_path)
# Display image
ax.imshow(predicted_image)
ax.set_title(f"Predictions for {image_file}")
ax.axis("off")
plt.tight_layout()
plt.show()
/home/kmcalist/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.) return F.conv2d(input, weight, bias, self.stride,
0: 448x640 1 elephant, 90.7ms Speed: 3.8ms preprocess, 90.7ms inference, 57.4ms postprocess per image at shape (1, 3, 448, 640) 0: 416x640 1 zebra, 87.5ms Speed: 0.9ms preprocess, 87.5ms inference, 1.4ms postprocess per image at shape (1, 3, 416, 640) 0: 480x640 1 cow, 86.2ms Speed: 2.3ms preprocess, 86.2ms inference, 0.7ms postprocess per image at shape (1, 3, 480, 640) 0: 512x640 2 zebras, 85.7ms Speed: 2.6ms preprocess, 85.7ms inference, 0.9ms postprocess per image at shape (1, 3, 512, 640) 0: 448x640 1 person, 2 cows, 5.7ms Speed: 1.1ms preprocess, 5.7ms inference, 0.6ms postprocess per image at shape (1, 3, 448, 640) 0: 448x640 1 cow, 5.5ms Speed: 1.7ms preprocess, 5.5ms inference, 0.9ms postprocess per image at shape (1, 3, 448, 640) 0: 448x640 1 zebra, 5.2ms Speed: 1.3ms preprocess, 5.2ms inference, 0.5ms postprocess per image at shape (1, 3, 448, 640) 0: 640x640 4 cows, 2 elephants, 5.5ms Speed: 0.9ms preprocess, 5.5ms inference, 0.6ms postprocess per image at shape (1, 3, 640, 640) 0: 640x448 1 elephant, 86.3ms Speed: 0.7ms preprocess, 86.3ms inference, 0.6ms postprocess per image at shape (1, 3, 640, 448)
Fine Tuning¶
In [5]:
from ultralytics import YOLO
# Load a pre-trained YOLOv11 model
model = YOLO("yolo11s.pt") # Use "yolov11s.pt" for better accuracy
# Train the model using your dataset
model.train(
data="wildlife.yaml", # Path to your dataset config
epochs=25, # Number of training epochs
imgsz=640, # Image size
batch=32, # Batch size (adjust based on GPU memory)
workers=12, # Number of CPU workers
device="cuda" # Use GPU if available, otherwise use "cpu"
)
New https://pypi.org/project/ultralytics/8.3.83 available 😃 Update with 'pip install -U ultralytics' Ultralytics 8.3.82 🚀 Python-3.10.12 torch-2.1.2+cu118 CUDA:0 (NVIDIA GeForce RTX 3090 Ti, 24245MiB) engine/trainer: task=detect, mode=train, model=yolo11s.pt, data=wildlife.yaml, epochs=25, time=None, patience=100, batch=32, imgsz=640, save=True, save_period=-1, cache=False, device=cuda, workers=12, project=None, name=train, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=True, show_conf=True, show_boxes=True, line_width=None, format=torchscript, keras=False, optimize=False, int8=False, dynamic=False, simplify=True, opset=None, workspace=None, nms=False, lr0=0.01, lrf=0.01, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=7.5, cls=0.5, dfl=1.5, pose=12.0, kobj=1.0, nbs=64, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, bgr=0.0, mosaic=1.0, mixup=0.0, copy_paste=0.0, copy_paste_mode=flip, auto_augment=randaugment, erasing=0.4, crop_fraction=1.0, cfg=None, tracker=botsort.yaml, save_dir=runs/detect/train Overriding model.yaml nc=80 with nc=4 from n params module arguments 0 -1 1 928 ultralytics.nn.modules.conv.Conv [3, 32, 3, 2] 1 -1 1 18560 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2] 2 -1 1 26080 ultralytics.nn.modules.block.C3k2 [64, 128, 1, False, 0.25] 3 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2] 4 -1 1 103360 ultralytics.nn.modules.block.C3k2 [128, 256, 1, False, 0.25] 5 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2] 6 -1 1 346112 ultralytics.nn.modules.block.C3k2 [256, 256, 1, True] 7 -1 1 1180672 ultralytics.nn.modules.conv.Conv [256, 512, 3, 2] 8 -1 1 1380352 ultralytics.nn.modules.block.C3k2 [512, 512, 1, True] 9 -1 1 656896 ultralytics.nn.modules.block.SPPF [512, 512, 5] 10 -1 1 990976 ultralytics.nn.modules.block.C2PSA [512, 512, 1] 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 12 [-1, 6] 1 0 ultralytics.nn.modules.conv.Concat [1] 13 -1 1 443776 ultralytics.nn.modules.block.C3k2 [768, 256, 1, False] 14 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 15 [-1, 4] 1 0 ultralytics.nn.modules.conv.Concat [1] 16 -1 1 127680 ultralytics.nn.modules.block.C3k2 [512, 128, 1, False] 17 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2] 18 [-1, 13] 1 0 ultralytics.nn.modules.conv.Concat [1] 19 -1 1 345472 ultralytics.nn.modules.block.C3k2 [384, 256, 1, False] 20 -1 1 590336 ultralytics.nn.modules.conv.Conv [256, 256, 3, 2] 21 [-1, 10] 1 0 ultralytics.nn.modules.conv.Concat [1] 22 -1 1 1511424 ultralytics.nn.modules.block.C3k2 [768, 512, 1, True] 23 [16, 19, 22] 1 820956 ultralytics.nn.modules.head.Detect [4, [128, 256, 512]] YOLO11s summary: 181 layers, 9,429,340 parameters, 9,429,324 gradients, 21.6 GFLOPs Transferred 493/499 items from pretrained weights WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. start() got an unexpected keyword argument 'project_name' WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. ----NeptuneMissingApiTokenException------------------------------------------- The Neptune client couldn't find your API token. You can get it here: - https://app.neptune.ai/get_my_api_token There are two options to add it: - specify it in your code - set an environment variable in your operating system. CODE Pass the token to the init_run() function via the api_token argument: neptune.init_run(project='WORKSPACE_NAME/PROJECT_NAME', api_token='YOUR_API_TOKEN') ENVIRONMENT VARIABLE (Recommended option) or export or set an environment variable depending on your operating system: Linux/Unix In your terminal run: export NEPTUNE_API_TOKEN="YOUR_API_TOKEN" Windows In your CMD run: set NEPTUNE_API_TOKEN="YOUR_API_TOKEN" and skip the api_token argument of the init_run() function: neptune.init_run(project='WORKSPACE_NAME/PROJECT_NAME') You may also want to check the following docs pages: - https://docs.neptune.ai/setup/setting_api_token/ Need help?-> https://docs.neptune.ai/getting_help TensorBoard: Start with 'tensorboard --logdir runs/detect/train', view at http://localhost:6006/ Freezing layer 'model.23.dfl.conv.weight' AMP: running Automatic Mixed Precision (AMP) checks...
[neptune] [warning] NeptuneWarning: The following monitoring options are disabled by default in interactive sessions: 'capture_stdout', 'capture_stderr', 'capture_traceback', and 'capture_hardware_metrics'. To enable them, set each parameter to 'True' when initializing the run. The monitoring will continue until you call run.stop() or the kernel stops. Also note: Your source files can only be tracked if you pass the path(s) to the 'source_code' argument. For help, see the Neptune docs: https://docs.neptune.ai/logging/source_code/
AMP: checks passed ✅
train: Scanning /home/kmcalist/QTM447/Spring2025/Lectures/Lecture15/data/wildlife/train/labels.cache... 1052 images, 0 backgrounds, 0 corrupt: 100%|██████████| 1052/1052 [00:00<?, ?it/s] val: Scanning /home/kmcalist/QTM447/Spring2025/Lectures/Lecture15/data/wildlife/valid/labels.cache... 225 images, 0 backgrounds, 0 corrupt: 100%|██████████| 225/225 [00:00<?, ?it/s]
Plotting labels to runs/detect/train/labels.jpg... optimizer: 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... optimizer: AdamW(lr=0.00125, momentum=0.9) with parameter groups 81 weight(decay=0.0), 88 weight(decay=0.0005), 87 bias(decay=0.0) TensorBoard: model graph visualization added ✅ Image sizes 640 train, 640 val Using 12 dataloader workers Logging results to runs/detect/train Starting training for 25 epochs... Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
1/25 8.42G 0.8596 2.283 1.256 98 640: 100%|██████████| 33/33 [00:06<00:00, 5.11it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:01<00:00, 3.56it/s]
all 225 379 0.604 0.604 0.691 0.472
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
2/25 8.18G 1.008 1.331 1.328 115 640: 100%|██████████| 33/33 [00:05<00:00, 6.07it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.54it/s]
all 225 379 0.608 0.555 0.595 0.375
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
3/25 8.48G 1.023 1.223 1.338 114 640: 100%|██████████| 33/33 [00:05<00:00, 6.28it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 5.38it/s]
all 225 379 0.48 0.335 0.32 0.175
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
4/25 8.06G 1.083 1.212 1.383 109 640: 100%|██████████| 33/33 [00:05<00:00, 6.19it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.50it/s]
all 225 379 0.655 0.209 0.219 0.126
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
5/25 8.35G 1.075 1.181 1.372 106 640: 100%|██████████| 33/33 [00:05<00:00, 6.13it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.47it/s]
all 225 379 0.475 0.452 0.433 0.255
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
6/25 8.2G 1.017 1.091 1.321 115 640: 100%|██████████| 33/33 [00:05<00:00, 6.18it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.62it/s]
all 225 379 0.704 0.66 0.735 0.499
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
7/25 8.47G 1.019 1.076 1.333 97 640: 100%|██████████| 33/33 [00:05<00:00, 6.18it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.45it/s]
all 225 379 0.815 0.541 0.72 0.516
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
8/25 8.19G 0.9524 1.011 1.281 127 640: 100%|██████████| 33/33 [00:05<00:00, 6.14it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.49it/s]
all 225 379 0.849 0.682 0.809 0.573
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
9/25 8.47G 0.9336 0.9513 1.265 142 640: 100%|██████████| 33/33 [00:05<00:00, 6.11it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.37it/s]
all 225 379 0.788 0.66 0.779 0.551
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
10/25 8.19G 0.902 0.8862 1.253 111 640: 100%|██████████| 33/33 [00:05<00:00, 6.19it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.61it/s]
all 225 379 0.8 0.743 0.822 0.604
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
11/25 8.47G 0.8912 0.8825 1.244 132 640: 100%|██████████| 33/33 [00:05<00:00, 6.17it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.49it/s]
all 225 379 0.831 0.761 0.862 0.624
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
12/25 8.19G 0.8687 0.8668 1.221 110 640: 100%|██████████| 33/33 [00:05<00:00, 6.19it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.62it/s]
all 225 379 0.871 0.77 0.873 0.649
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
13/25 8.47G 0.8227 0.7707 1.188 122 640: 100%|██████████| 33/33 [00:05<00:00, 6.12it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.53it/s]
all 225 379 0.864 0.827 0.892 0.694
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
14/25 8.1G 0.811 0.7525 1.182 112 640: 100%|██████████| 33/33 [00:05<00:00, 6.06it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.41it/s]
all 225 379 0.812 0.81 0.875 0.676
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
15/25 8.47G 0.8019 0.7402 1.179 133 640: 100%|██████████| 33/33 [00:05<00:00, 6.13it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.61it/s]
all 225 379 0.94 0.842 0.921 0.728
Closing dataloader mosaic
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
16/25 8.19G 0.7612 0.6495 1.171 40 640: 100%|██████████| 33/33 [00:05<00:00, 5.59it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.59it/s]
all 225 379 0.918 0.837 0.919 0.723
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
17/25 8.47G 0.7296 0.631 1.137 42 640: 100%|██████████| 33/33 [00:05<00:00, 6.13it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.63it/s]
all 225 379 0.749 0.792 0.823 0.633
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
18/25 8.19G 0.695 0.599 1.109 54 640: 100%|██████████| 33/33 [00:05<00:00, 6.17it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.62it/s]
all 225 379 0.902 0.789 0.903 0.724
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
19/25 8.47G 0.6755 0.5599 1.1 53 640: 100%|██████████| 33/33 [00:05<00:00, 6.21it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.47it/s]
all 225 379 0.929 0.848 0.925 0.757
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
20/25 8.19G 0.6638 0.5076 1.093 43 640: 100%|██████████| 33/33 [00:05<00:00, 6.20it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.57it/s]
all 225 379 0.864 0.863 0.919 0.758
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
21/25 8.47G 0.634 0.4936 1.06 51 640: 100%|██████████| 33/33 [00:05<00:00, 6.28it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.74it/s]
all 225 379 0.915 0.851 0.933 0.761
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
22/25 8.19G 0.6195 0.4777 1.062 53 640: 100%|██████████| 33/33 [00:05<00:00, 6.19it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.50it/s]
all 225 379 0.897 0.875 0.932 0.776
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
23/25 8.47G 0.6009 0.4635 1.045 43 640: 100%|██████████| 33/33 [00:05<00:00, 6.21it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.73it/s]
all 225 379 0.944 0.882 0.95 0.797
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
24/25 8.19G 0.5726 0.428 1.03 46 640: 100%|██████████| 33/33 [00:05<00:00, 6.20it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.33it/s]
all 225 379 0.935 0.883 0.952 0.8
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
25/25 8.47G 0.5576 0.3874 1.003 46 640: 100%|██████████| 33/33 [00:05<00:00, 6.18it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 6.63it/s]
all 225 379 0.938 0.885 0.951 0.807
25 epochs completed in 0.048 hours. Optimizer stripped from runs/detect/train/weights/last.pt, 19.2MB Optimizer stripped from runs/detect/train/weights/best.pt, 19.2MB Validating runs/detect/train/weights/best.pt... Ultralytics 8.3.82 🚀 Python-3.10.12 torch-2.1.2+cu118 CUDA:0 (NVIDIA GeForce RTX 3090 Ti, 24245MiB) YOLO11s summary (fused): 100 layers, 9,414,348 parameters, 0 gradients, 21.3 GFLOPs
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 4/4 [00:00<00:00, 4.62it/s]
all 225 379 0.938 0.885 0.951 0.807
buffalo 62 89 0.949 0.888 0.955 0.848
elephant 53 91 0.894 0.868 0.94 0.761
rhino 55 85 0.941 0.943 0.966 0.867
zebra 59 114 0.97 0.842 0.943 0.751
Speed: 0.1ms preprocess, 1.2ms inference, 0.0ms loss, 0.7ms postprocess per image
Results saved to runs/detect/train
Out[5]:
ultralytics.utils.metrics.DetMetrics object with attributes:
ap_class_index: array([0, 1, 2, 3])
box: ultralytics.utils.metrics.Metric object
confusion_matrix: <ultralytics.utils.metrics.ConfusionMatrix object at 0x7835f50be4d0>
curves: ['Precision-Recall(B)', 'F1-Confidence(B)', 'Precision-Confidence(B)', 'Recall-Confidence(B)']
curves_results: [[array([ 0, 0.001001, 0.002002, 0.003003, 0.004004, 0.005005, 0.006006, 0.007007, 0.008008, 0.009009, 0.01001, 0.011011, 0.012012, 0.013013, 0.014014, 0.015015, 0.016016, 0.017017, 0.018018, 0.019019, 0.02002, 0.021021, 0.022022, 0.023023,
0.024024, 0.025025, 0.026026, 0.027027, 0.028028, 0.029029, 0.03003, 0.031031, 0.032032, 0.033033, 0.034034, 0.035035, 0.036036, 0.037037, 0.038038, 0.039039, 0.04004, 0.041041, 0.042042, 0.043043, 0.044044, 0.045045, 0.046046, 0.047047,
0.048048, 0.049049, 0.05005, 0.051051, 0.052052, 0.053053, 0.054054, 0.055055, 0.056056, 0.057057, 0.058058, 0.059059, 0.06006, 0.061061, 0.062062, 0.063063, 0.064064, 0.065065, 0.066066, 0.067067, 0.068068, 0.069069, 0.07007, 0.071071,
0.072072, 0.073073, 0.074074, 0.075075, 0.076076, 0.077077, 0.078078, 0.079079, 0.08008, 0.081081, 0.082082, 0.083083, 0.084084, 0.085085, 0.086086, 0.087087, 0.088088, 0.089089, 0.09009, 0.091091, 0.092092, 0.093093, 0.094094, 0.095095,
0.096096, 0.097097, 0.098098, 0.099099, 0.1001, 0.1011, 0.1021, 0.1031, 0.1041, 0.10511, 0.10611, 0.10711, 0.10811, 0.10911, 0.11011, 0.11111, 0.11211, 0.11311, 0.11411, 0.11512, 0.11612, 0.11712, 0.11812, 0.11912,
0.12012, 0.12112, 0.12212, 0.12312, 0.12412, 0.12513, 0.12613, 0.12713, 0.12813, 0.12913, 0.13013, 0.13113, 0.13213, 0.13313, 0.13413, 0.13514, 0.13614, 0.13714, 0.13814, 0.13914, 0.14014, 0.14114, 0.14214, 0.14314,
0.14414, 0.14515, 0.14615, 0.14715, 0.14815, 0.14915, 0.15015, 0.15115, 0.15215, 0.15315, 0.15415, 0.15516, 0.15616, 0.15716, 0.15816, 0.15916, 0.16016, 0.16116, 0.16216, 0.16316, 0.16416, 0.16517, 0.16617, 0.16717,
0.16817, 0.16917, 0.17017, 0.17117, 0.17217, 0.17317, 0.17417, 0.17518, 0.17618, 0.17718, 0.17818, 0.17918, 0.18018, 0.18118, 0.18218, 0.18318, 0.18418, 0.18519, 0.18619, 0.18719, 0.18819, 0.18919, 0.19019, 0.19119,
0.19219, 0.19319, 0.19419, 0.1952, 0.1962, 0.1972, 0.1982, 0.1992, 0.2002, 0.2012, 0.2022, 0.2032, 0.2042, 0.20521, 0.20621, 0.20721, 0.20821, 0.20921, 0.21021, 0.21121, 0.21221, 0.21321, 0.21421, 0.21522,
0.21622, 0.21722, 0.21822, 0.21922, 0.22022, 0.22122, 0.22222, 0.22322, 0.22422, 0.22523, 0.22623, 0.22723, 0.22823, 0.22923, 0.23023, 0.23123, 0.23223, 0.23323, 0.23423, 0.23524, 0.23624, 0.23724, 0.23824, 0.23924,
0.24024, 0.24124, 0.24224, 0.24324, 0.24424, 0.24525, 0.24625, 0.24725, 0.24825, 0.24925, 0.25025, 0.25125, 0.25225, 0.25325, 0.25425, 0.25526, 0.25626, 0.25726, 0.25826, 0.25926, 0.26026, 0.26126, 0.26226, 0.26326,
0.26426, 0.26527, 0.26627, 0.26727, 0.26827, 0.26927, 0.27027, 0.27127, 0.27227, 0.27327, 0.27427, 0.27528, 0.27628, 0.27728, 0.27828, 0.27928, 0.28028, 0.28128, 0.28228, 0.28328, 0.28428, 0.28529, 0.28629, 0.28729,
0.28829, 0.28929, 0.29029, 0.29129, 0.29229, 0.29329, 0.29429, 0.2953, 0.2963, 0.2973, 0.2983, 0.2993, 0.3003, 0.3013, 0.3023, 0.3033, 0.3043, 0.30531, 0.30631, 0.30731, 0.30831, 0.30931, 0.31031, 0.31131,
0.31231, 0.31331, 0.31431, 0.31532, 0.31632, 0.31732, 0.31832, 0.31932, 0.32032, 0.32132, 0.32232, 0.32332, 0.32432, 0.32533, 0.32633, 0.32733, 0.32833, 0.32933, 0.33033, 0.33133, 0.33233, 0.33333, 0.33433, 0.33534,
0.33634, 0.33734, 0.33834, 0.33934, 0.34034, 0.34134, 0.34234, 0.34334, 0.34434, 0.34535, 0.34635, 0.34735, 0.34835, 0.34935, 0.35035, 0.35135, 0.35235, 0.35335, 0.35435, 0.35536, 0.35636, 0.35736, 0.35836, 0.35936,
0.36036, 0.36136, 0.36236, 0.36336, 0.36436, 0.36537, 0.36637, 0.36737, 0.36837, 0.36937, 0.37037, 0.37137, 0.37237, 0.37337, 0.37437, 0.37538, 0.37638, 0.37738, 0.37838, 0.37938, 0.38038, 0.38138, 0.38238, 0.38338,
0.38438, 0.38539, 0.38639, 0.38739, 0.38839, 0.38939, 0.39039, 0.39139, 0.39239, 0.39339, 0.39439, 0.3954, 0.3964, 0.3974, 0.3984, 0.3994, 0.4004, 0.4014, 0.4024, 0.4034, 0.4044, 0.40541, 0.40641, 0.40741,
0.40841, 0.40941, 0.41041, 0.41141, 0.41241, 0.41341, 0.41441, 0.41542, 0.41642, 0.41742, 0.41842, 0.41942, 0.42042, 0.42142, 0.42242, 0.42342, 0.42442, 0.42543, 0.42643, 0.42743, 0.42843, 0.42943, 0.43043, 0.43143,
0.43243, 0.43343, 0.43443, 0.43544, 0.43644, 0.43744, 0.43844, 0.43944, 0.44044, 0.44144, 0.44244, 0.44344, 0.44444, 0.44545, 0.44645, 0.44745, 0.44845, 0.44945, 0.45045, 0.45145, 0.45245, 0.45345, 0.45445, 0.45546,
0.45646, 0.45746, 0.45846, 0.45946, 0.46046, 0.46146, 0.46246, 0.46346, 0.46446, 0.46547, 0.46647, 0.46747, 0.46847, 0.46947, 0.47047, 0.47147, 0.47247, 0.47347, 0.47447, 0.47548, 0.47648, 0.47748, 0.47848, 0.47948,
0.48048, 0.48148, 0.48248, 0.48348, 0.48448, 0.48549, 0.48649, 0.48749, 0.48849, 0.48949, 0.49049, 0.49149, 0.49249, 0.49349, 0.49449, 0.4955, 0.4965, 0.4975, 0.4985, 0.4995, 0.5005, 0.5015, 0.5025, 0.5035,
0.5045, 0.50551, 0.50651, 0.50751, 0.50851, 0.50951, 0.51051, 0.51151, 0.51251, 0.51351, 0.51451, 0.51552, 0.51652, 0.51752, 0.51852, 0.51952, 0.52052, 0.52152, 0.52252, 0.52352, 0.52452, 0.52553, 0.52653, 0.52753,
0.52853, 0.52953, 0.53053, 0.53153, 0.53253, 0.53353, 0.53453, 0.53554, 0.53654, 0.53754, 0.53854, 0.53954, 0.54054, 0.54154, 0.54254, 0.54354, 0.54454, 0.54555, 0.54655, 0.54755, 0.54855, 0.54955, 0.55055, 0.55155,
0.55255, 0.55355, 0.55455, 0.55556, 0.55656, 0.55756, 0.55856, 0.55956, 0.56056, 0.56156, 0.56256, 0.56356, 0.56456, 0.56557, 0.56657, 0.56757, 0.56857, 0.56957, 0.57057, 0.57157, 0.57257, 0.57357, 0.57457, 0.57558,
0.57658, 0.57758, 0.57858, 0.57958, 0.58058, 0.58158, 0.58258, 0.58358, 0.58458, 0.58559, 0.58659, 0.58759, 0.58859, 0.58959, 0.59059, 0.59159, 0.59259, 0.59359, 0.59459, 0.5956, 0.5966, 0.5976, 0.5986, 0.5996,
0.6006, 0.6016, 0.6026, 0.6036, 0.6046, 0.60561, 0.60661, 0.60761, 0.60861, 0.60961, 0.61061, 0.61161, 0.61261, 0.61361, 0.61461, 0.61562, 0.61662, 0.61762, 0.61862, 0.61962, 0.62062, 0.62162, 0.62262, 0.62362,
0.62462, 0.62563, 0.62663, 0.62763, 0.62863, 0.62963, 0.63063, 0.63163, 0.63263, 0.63363, 0.63463, 0.63564, 0.63664, 0.63764, 0.63864, 0.63964, 0.64064, 0.64164, 0.64264, 0.64364, 0.64464, 0.64565, 0.64665, 0.64765,
0.64865, 0.64965, 0.65065, 0.65165, 0.65265, 0.65365, 0.65465, 0.65566, 0.65666, 0.65766, 0.65866, 0.65966, 0.66066, 0.66166, 0.66266, 0.66366, 0.66466, 0.66567, 0.66667, 0.66767, 0.66867, 0.66967, 0.67067, 0.67167,
0.67267, 0.67367, 0.67467, 0.67568, 0.67668, 0.67768, 0.67868, 0.67968, 0.68068, 0.68168, 0.68268, 0.68368, 0.68468, 0.68569, 0.68669, 0.68769, 0.68869, 0.68969, 0.69069, 0.69169, 0.69269, 0.69369, 0.69469, 0.6957,
0.6967, 0.6977, 0.6987, 0.6997, 0.7007, 0.7017, 0.7027, 0.7037, 0.7047, 0.70571, 0.70671, 0.70771, 0.70871, 0.70971, 0.71071, 0.71171, 0.71271, 0.71371, 0.71471, 0.71572, 0.71672, 0.71772, 0.71872, 0.71972,
0.72072, 0.72172, 0.72272, 0.72372, 0.72472, 0.72573, 0.72673, 0.72773, 0.72873, 0.72973, 0.73073, 0.73173, 0.73273, 0.73373, 0.73473, 0.73574, 0.73674, 0.73774, 0.73874, 0.73974, 0.74074, 0.74174, 0.74274, 0.74374,
0.74474, 0.74575, 0.74675, 0.74775, 0.74875, 0.74975, 0.75075, 0.75175, 0.75275, 0.75375, 0.75475, 0.75576, 0.75676, 0.75776, 0.75876, 0.75976, 0.76076, 0.76176, 0.76276, 0.76376, 0.76476, 0.76577, 0.76677, 0.76777,
0.76877, 0.76977, 0.77077, 0.77177, 0.77277, 0.77377, 0.77477, 0.77578, 0.77678, 0.77778, 0.77878, 0.77978, 0.78078, 0.78178, 0.78278, 0.78378, 0.78478, 0.78579, 0.78679, 0.78779, 0.78879, 0.78979, 0.79079, 0.79179,
0.79279, 0.79379, 0.79479, 0.7958, 0.7968, 0.7978, 0.7988, 0.7998, 0.8008, 0.8018, 0.8028, 0.8038, 0.8048, 0.80581, 0.80681, 0.80781, 0.80881, 0.80981, 0.81081, 0.81181, 0.81281, 0.81381, 0.81481, 0.81582,
0.81682, 0.81782, 0.81882, 0.81982, 0.82082, 0.82182, 0.82282, 0.82382, 0.82482, 0.82583, 0.82683, 0.82783, 0.82883, 0.82983, 0.83083, 0.83183, 0.83283, 0.83383, 0.83483, 0.83584, 0.83684, 0.83784, 0.83884, 0.83984,
0.84084, 0.84184, 0.84284, 0.84384, 0.84484, 0.84585, 0.84685, 0.84785, 0.84885, 0.84985, 0.85085, 0.85185, 0.85285, 0.85385, 0.85485, 0.85586, 0.85686, 0.85786, 0.85886, 0.85986, 0.86086, 0.86186, 0.86286, 0.86386,
0.86486, 0.86587, 0.86687, 0.86787, 0.86887, 0.86987, 0.87087, 0.87187, 0.87287, 0.87387, 0.87487, 0.87588, 0.87688, 0.87788, 0.87888, 0.87988, 0.88088, 0.88188, 0.88288, 0.88388, 0.88488, 0.88589, 0.88689, 0.88789,
0.88889, 0.88989, 0.89089, 0.89189, 0.89289, 0.89389, 0.89489, 0.8959, 0.8969, 0.8979, 0.8989, 0.8999, 0.9009, 0.9019, 0.9029, 0.9039, 0.9049, 0.90591, 0.90691, 0.90791, 0.90891, 0.90991, 0.91091, 0.91191,
0.91291, 0.91391, 0.91491, 0.91592, 0.91692, 0.91792, 0.91892, 0.91992, 0.92092, 0.92192, 0.92292, 0.92392, 0.92492, 0.92593, 0.92693, 0.92793, 0.92893, 0.92993, 0.93093, 0.93193, 0.93293, 0.93393, 0.93493, 0.93594,
0.93694, 0.93794, 0.93894, 0.93994, 0.94094, 0.94194, 0.94294, 0.94394, 0.94494, 0.94595, 0.94695, 0.94795, 0.94895, 0.94995, 0.95095, 0.95195, 0.95295, 0.95395, 0.95495, 0.95596, 0.95696, 0.95796, 0.95896, 0.95996,
0.96096, 0.96196, 0.96296, 0.96396, 0.96496, 0.96597, 0.96697, 0.96797, 0.96897, 0.96997, 0.97097, 0.97197, 0.97297, 0.97397, 0.97497, 0.97598, 0.97698, 0.97798, 0.97898, 0.97998, 0.98098, 0.98198, 0.98298, 0.98398,
0.98498, 0.98599, 0.98699, 0.98799, 0.98899, 0.98999, 0.99099, 0.99199, 0.99299, 0.99399, 0.99499, 0.996, 0.997, 0.998, 0.999, 1]), array([[ 1, 1, 1, ..., 0.18737, 0.18737, 0],
[ 1, 1, 1, ..., 0.14725, 0.14725, 0],
[ 1, 1, 1, ..., 0.046713, 0.023357, 0],
[ 1, 1, 1, ..., 0.015325, 0.0076623, 0]]), 'Recall', 'Precision'], [array([ 0, 0.001001, 0.002002, 0.003003, 0.004004, 0.005005, 0.006006, 0.007007, 0.008008, 0.009009, 0.01001, 0.011011, 0.012012, 0.013013, 0.014014, 0.015015, 0.016016, 0.017017, 0.018018, 0.019019, 0.02002, 0.021021, 0.022022, 0.023023,
0.024024, 0.025025, 0.026026, 0.027027, 0.028028, 0.029029, 0.03003, 0.031031, 0.032032, 0.033033, 0.034034, 0.035035, 0.036036, 0.037037, 0.038038, 0.039039, 0.04004, 0.041041, 0.042042, 0.043043, 0.044044, 0.045045, 0.046046, 0.047047,
0.048048, 0.049049, 0.05005, 0.051051, 0.052052, 0.053053, 0.054054, 0.055055, 0.056056, 0.057057, 0.058058, 0.059059, 0.06006, 0.061061, 0.062062, 0.063063, 0.064064, 0.065065, 0.066066, 0.067067, 0.068068, 0.069069, 0.07007, 0.071071,
0.072072, 0.073073, 0.074074, 0.075075, 0.076076, 0.077077, 0.078078, 0.079079, 0.08008, 0.081081, 0.082082, 0.083083, 0.084084, 0.085085, 0.086086, 0.087087, 0.088088, 0.089089, 0.09009, 0.091091, 0.092092, 0.093093, 0.094094, 0.095095,
0.096096, 0.097097, 0.098098, 0.099099, 0.1001, 0.1011, 0.1021, 0.1031, 0.1041, 0.10511, 0.10611, 0.10711, 0.10811, 0.10911, 0.11011, 0.11111, 0.11211, 0.11311, 0.11411, 0.11512, 0.11612, 0.11712, 0.11812, 0.11912,
0.12012, 0.12112, 0.12212, 0.12312, 0.12412, 0.12513, 0.12613, 0.12713, 0.12813, 0.12913, 0.13013, 0.13113, 0.13213, 0.13313, 0.13413, 0.13514, 0.13614, 0.13714, 0.13814, 0.13914, 0.14014, 0.14114, 0.14214, 0.14314,
0.14414, 0.14515, 0.14615, 0.14715, 0.14815, 0.14915, 0.15015, 0.15115, 0.15215, 0.15315, 0.15415, 0.15516, 0.15616, 0.15716, 0.15816, 0.15916, 0.16016, 0.16116, 0.16216, 0.16316, 0.16416, 0.16517, 0.16617, 0.16717,
0.16817, 0.16917, 0.17017, 0.17117, 0.17217, 0.17317, 0.17417, 0.17518, 0.17618, 0.17718, 0.17818, 0.17918, 0.18018, 0.18118, 0.18218, 0.18318, 0.18418, 0.18519, 0.18619, 0.18719, 0.18819, 0.18919, 0.19019, 0.19119,
0.19219, 0.19319, 0.19419, 0.1952, 0.1962, 0.1972, 0.1982, 0.1992, 0.2002, 0.2012, 0.2022, 0.2032, 0.2042, 0.20521, 0.20621, 0.20721, 0.20821, 0.20921, 0.21021, 0.21121, 0.21221, 0.21321, 0.21421, 0.21522,
0.21622, 0.21722, 0.21822, 0.21922, 0.22022, 0.22122, 0.22222, 0.22322, 0.22422, 0.22523, 0.22623, 0.22723, 0.22823, 0.22923, 0.23023, 0.23123, 0.23223, 0.23323, 0.23423, 0.23524, 0.23624, 0.23724, 0.23824, 0.23924,
0.24024, 0.24124, 0.24224, 0.24324, 0.24424, 0.24525, 0.24625, 0.24725, 0.24825, 0.24925, 0.25025, 0.25125, 0.25225, 0.25325, 0.25425, 0.25526, 0.25626, 0.25726, 0.25826, 0.25926, 0.26026, 0.26126, 0.26226, 0.26326,
0.26426, 0.26527, 0.26627, 0.26727, 0.26827, 0.26927, 0.27027, 0.27127, 0.27227, 0.27327, 0.27427, 0.27528, 0.27628, 0.27728, 0.27828, 0.27928, 0.28028, 0.28128, 0.28228, 0.28328, 0.28428, 0.28529, 0.28629, 0.28729,
0.28829, 0.28929, 0.29029, 0.29129, 0.29229, 0.29329, 0.29429, 0.2953, 0.2963, 0.2973, 0.2983, 0.2993, 0.3003, 0.3013, 0.3023, 0.3033, 0.3043, 0.30531, 0.30631, 0.30731, 0.30831, 0.30931, 0.31031, 0.31131,
0.31231, 0.31331, 0.31431, 0.31532, 0.31632, 0.31732, 0.31832, 0.31932, 0.32032, 0.32132, 0.32232, 0.32332, 0.32432, 0.32533, 0.32633, 0.32733, 0.32833, 0.32933, 0.33033, 0.33133, 0.33233, 0.33333, 0.33433, 0.33534,
0.33634, 0.33734, 0.33834, 0.33934, 0.34034, 0.34134, 0.34234, 0.34334, 0.34434, 0.34535, 0.34635, 0.34735, 0.34835, 0.34935, 0.35035, 0.35135, 0.35235, 0.35335, 0.35435, 0.35536, 0.35636, 0.35736, 0.35836, 0.35936,
0.36036, 0.36136, 0.36236, 0.36336, 0.36436, 0.36537, 0.36637, 0.36737, 0.36837, 0.36937, 0.37037, 0.37137, 0.37237, 0.37337, 0.37437, 0.37538, 0.37638, 0.37738, 0.37838, 0.37938, 0.38038, 0.38138, 0.38238, 0.38338,
0.38438, 0.38539, 0.38639, 0.38739, 0.38839, 0.38939, 0.39039, 0.39139, 0.39239, 0.39339, 0.39439, 0.3954, 0.3964, 0.3974, 0.3984, 0.3994, 0.4004, 0.4014, 0.4024, 0.4034, 0.4044, 0.40541, 0.40641, 0.40741,
0.40841, 0.40941, 0.41041, 0.41141, 0.41241, 0.41341, 0.41441, 0.41542, 0.41642, 0.41742, 0.41842, 0.41942, 0.42042, 0.42142, 0.42242, 0.42342, 0.42442, 0.42543, 0.42643, 0.42743, 0.42843, 0.42943, 0.43043, 0.43143,
0.43243, 0.43343, 0.43443, 0.43544, 0.43644, 0.43744, 0.43844, 0.43944, 0.44044, 0.44144, 0.44244, 0.44344, 0.44444, 0.44545, 0.44645, 0.44745, 0.44845, 0.44945, 0.45045, 0.45145, 0.45245, 0.45345, 0.45445, 0.45546,
0.45646, 0.45746, 0.45846, 0.45946, 0.46046, 0.46146, 0.46246, 0.46346, 0.46446, 0.46547, 0.46647, 0.46747, 0.46847, 0.46947, 0.47047, 0.47147, 0.47247, 0.47347, 0.47447, 0.47548, 0.47648, 0.47748, 0.47848, 0.47948,
0.48048, 0.48148, 0.48248, 0.48348, 0.48448, 0.48549, 0.48649, 0.48749, 0.48849, 0.48949, 0.49049, 0.49149, 0.49249, 0.49349, 0.49449, 0.4955, 0.4965, 0.4975, 0.4985, 0.4995, 0.5005, 0.5015, 0.5025, 0.5035,
0.5045, 0.50551, 0.50651, 0.50751, 0.50851, 0.50951, 0.51051, 0.51151, 0.51251, 0.51351, 0.51451, 0.51552, 0.51652, 0.51752, 0.51852, 0.51952, 0.52052, 0.52152, 0.52252, 0.52352, 0.52452, 0.52553, 0.52653, 0.52753,
0.52853, 0.52953, 0.53053, 0.53153, 0.53253, 0.53353, 0.53453, 0.53554, 0.53654, 0.53754, 0.53854, 0.53954, 0.54054, 0.54154, 0.54254, 0.54354, 0.54454, 0.54555, 0.54655, 0.54755, 0.54855, 0.54955, 0.55055, 0.55155,
0.55255, 0.55355, 0.55455, 0.55556, 0.55656, 0.55756, 0.55856, 0.55956, 0.56056, 0.56156, 0.56256, 0.56356, 0.56456, 0.56557, 0.56657, 0.56757, 0.56857, 0.56957, 0.57057, 0.57157, 0.57257, 0.57357, 0.57457, 0.57558,
0.57658, 0.57758, 0.57858, 0.57958, 0.58058, 0.58158, 0.58258, 0.58358, 0.58458, 0.58559, 0.58659, 0.58759, 0.58859, 0.58959, 0.59059, 0.59159, 0.59259, 0.59359, 0.59459, 0.5956, 0.5966, 0.5976, 0.5986, 0.5996,
0.6006, 0.6016, 0.6026, 0.6036, 0.6046, 0.60561, 0.60661, 0.60761, 0.60861, 0.60961, 0.61061, 0.61161, 0.61261, 0.61361, 0.61461, 0.61562, 0.61662, 0.61762, 0.61862, 0.61962, 0.62062, 0.62162, 0.62262, 0.62362,
0.62462, 0.62563, 0.62663, 0.62763, 0.62863, 0.62963, 0.63063, 0.63163, 0.63263, 0.63363, 0.63463, 0.63564, 0.63664, 0.63764, 0.63864, 0.63964, 0.64064, 0.64164, 0.64264, 0.64364, 0.64464, 0.64565, 0.64665, 0.64765,
0.64865, 0.64965, 0.65065, 0.65165, 0.65265, 0.65365, 0.65465, 0.65566, 0.65666, 0.65766, 0.65866, 0.65966, 0.66066, 0.66166, 0.66266, 0.66366, 0.66466, 0.66567, 0.66667, 0.66767, 0.66867, 0.66967, 0.67067, 0.67167,
0.67267, 0.67367, 0.67467, 0.67568, 0.67668, 0.67768, 0.67868, 0.67968, 0.68068, 0.68168, 0.68268, 0.68368, 0.68468, 0.68569, 0.68669, 0.68769, 0.68869, 0.68969, 0.69069, 0.69169, 0.69269, 0.69369, 0.69469, 0.6957,
0.6967, 0.6977, 0.6987, 0.6997, 0.7007, 0.7017, 0.7027, 0.7037, 0.7047, 0.70571, 0.70671, 0.70771, 0.70871, 0.70971, 0.71071, 0.71171, 0.71271, 0.71371, 0.71471, 0.71572, 0.71672, 0.71772, 0.71872, 0.71972,
0.72072, 0.72172, 0.72272, 0.72372, 0.72472, 0.72573, 0.72673, 0.72773, 0.72873, 0.72973, 0.73073, 0.73173, 0.73273, 0.73373, 0.73473, 0.73574, 0.73674, 0.73774, 0.73874, 0.73974, 0.74074, 0.74174, 0.74274, 0.74374,
0.74474, 0.74575, 0.74675, 0.74775, 0.74875, 0.74975, 0.75075, 0.75175, 0.75275, 0.75375, 0.75475, 0.75576, 0.75676, 0.75776, 0.75876, 0.75976, 0.76076, 0.76176, 0.76276, 0.76376, 0.76476, 0.76577, 0.76677, 0.76777,
0.76877, 0.76977, 0.77077, 0.77177, 0.77277, 0.77377, 0.77477, 0.77578, 0.77678, 0.77778, 0.77878, 0.77978, 0.78078, 0.78178, 0.78278, 0.78378, 0.78478, 0.78579, 0.78679, 0.78779, 0.78879, 0.78979, 0.79079, 0.79179,
0.79279, 0.79379, 0.79479, 0.7958, 0.7968, 0.7978, 0.7988, 0.7998, 0.8008, 0.8018, 0.8028, 0.8038, 0.8048, 0.80581, 0.80681, 0.80781, 0.80881, 0.80981, 0.81081, 0.81181, 0.81281, 0.81381, 0.81481, 0.81582,
0.81682, 0.81782, 0.81882, 0.81982, 0.82082, 0.82182, 0.82282, 0.82382, 0.82482, 0.82583, 0.82683, 0.82783, 0.82883, 0.82983, 0.83083, 0.83183, 0.83283, 0.83383, 0.83483, 0.83584, 0.83684, 0.83784, 0.83884, 0.83984,
0.84084, 0.84184, 0.84284, 0.84384, 0.84484, 0.84585, 0.84685, 0.84785, 0.84885, 0.84985, 0.85085, 0.85185, 0.85285, 0.85385, 0.85485, 0.85586, 0.85686, 0.85786, 0.85886, 0.85986, 0.86086, 0.86186, 0.86286, 0.86386,
0.86486, 0.86587, 0.86687, 0.86787, 0.86887, 0.86987, 0.87087, 0.87187, 0.87287, 0.87387, 0.87487, 0.87588, 0.87688, 0.87788, 0.87888, 0.87988, 0.88088, 0.88188, 0.88288, 0.88388, 0.88488, 0.88589, 0.88689, 0.88789,
0.88889, 0.88989, 0.89089, 0.89189, 0.89289, 0.89389, 0.89489, 0.8959, 0.8969, 0.8979, 0.8989, 0.8999, 0.9009, 0.9019, 0.9029, 0.9039, 0.9049, 0.90591, 0.90691, 0.90791, 0.90891, 0.90991, 0.91091, 0.91191,
0.91291, 0.91391, 0.91491, 0.91592, 0.91692, 0.91792, 0.91892, 0.91992, 0.92092, 0.92192, 0.92292, 0.92392, 0.92492, 0.92593, 0.92693, 0.92793, 0.92893, 0.92993, 0.93093, 0.93193, 0.93293, 0.93393, 0.93493, 0.93594,
0.93694, 0.93794, 0.93894, 0.93994, 0.94094, 0.94194, 0.94294, 0.94394, 0.94494, 0.94595, 0.94695, 0.94795, 0.94895, 0.94995, 0.95095, 0.95195, 0.95295, 0.95395, 0.95495, 0.95596, 0.95696, 0.95796, 0.95896, 0.95996,
0.96096, 0.96196, 0.96296, 0.96396, 0.96496, 0.96597, 0.96697, 0.96797, 0.96897, 0.96997, 0.97097, 0.97197, 0.97297, 0.97397, 0.97497, 0.97598, 0.97698, 0.97798, 0.97898, 0.97998, 0.98098, 0.98198, 0.98298, 0.98398,
0.98498, 0.98599, 0.98699, 0.98799, 0.98899, 0.98999, 0.99099, 0.99199, 0.99299, 0.99399, 0.99499, 0.996, 0.997, 0.998, 0.999, 1]), array([[ 0.25284, 0.25284, 0.35988, ..., 0, 0, 0],
[ 0.20776, 0.20776, 0.30045, ..., 0, 0, 0],
[ 0.42967, 0.42967, 0.5577, ..., 0, 0, 0],
[ 0.23629, 0.23629, 0.30808, ..., 0, 0, 0]]), 'Confidence', 'F1'], [array([ 0, 0.001001, 0.002002, 0.003003, 0.004004, 0.005005, 0.006006, 0.007007, 0.008008, 0.009009, 0.01001, 0.011011, 0.012012, 0.013013, 0.014014, 0.015015, 0.016016, 0.017017, 0.018018, 0.019019, 0.02002, 0.021021, 0.022022, 0.023023,
0.024024, 0.025025, 0.026026, 0.027027, 0.028028, 0.029029, 0.03003, 0.031031, 0.032032, 0.033033, 0.034034, 0.035035, 0.036036, 0.037037, 0.038038, 0.039039, 0.04004, 0.041041, 0.042042, 0.043043, 0.044044, 0.045045, 0.046046, 0.047047,
0.048048, 0.049049, 0.05005, 0.051051, 0.052052, 0.053053, 0.054054, 0.055055, 0.056056, 0.057057, 0.058058, 0.059059, 0.06006, 0.061061, 0.062062, 0.063063, 0.064064, 0.065065, 0.066066, 0.067067, 0.068068, 0.069069, 0.07007, 0.071071,
0.072072, 0.073073, 0.074074, 0.075075, 0.076076, 0.077077, 0.078078, 0.079079, 0.08008, 0.081081, 0.082082, 0.083083, 0.084084, 0.085085, 0.086086, 0.087087, 0.088088, 0.089089, 0.09009, 0.091091, 0.092092, 0.093093, 0.094094, 0.095095,
0.096096, 0.097097, 0.098098, 0.099099, 0.1001, 0.1011, 0.1021, 0.1031, 0.1041, 0.10511, 0.10611, 0.10711, 0.10811, 0.10911, 0.11011, 0.11111, 0.11211, 0.11311, 0.11411, 0.11512, 0.11612, 0.11712, 0.11812, 0.11912,
0.12012, 0.12112, 0.12212, 0.12312, 0.12412, 0.12513, 0.12613, 0.12713, 0.12813, 0.12913, 0.13013, 0.13113, 0.13213, 0.13313, 0.13413, 0.13514, 0.13614, 0.13714, 0.13814, 0.13914, 0.14014, 0.14114, 0.14214, 0.14314,
0.14414, 0.14515, 0.14615, 0.14715, 0.14815, 0.14915, 0.15015, 0.15115, 0.15215, 0.15315, 0.15415, 0.15516, 0.15616, 0.15716, 0.15816, 0.15916, 0.16016, 0.16116, 0.16216, 0.16316, 0.16416, 0.16517, 0.16617, 0.16717,
0.16817, 0.16917, 0.17017, 0.17117, 0.17217, 0.17317, 0.17417, 0.17518, 0.17618, 0.17718, 0.17818, 0.17918, 0.18018, 0.18118, 0.18218, 0.18318, 0.18418, 0.18519, 0.18619, 0.18719, 0.18819, 0.18919, 0.19019, 0.19119,
0.19219, 0.19319, 0.19419, 0.1952, 0.1962, 0.1972, 0.1982, 0.1992, 0.2002, 0.2012, 0.2022, 0.2032, 0.2042, 0.20521, 0.20621, 0.20721, 0.20821, 0.20921, 0.21021, 0.21121, 0.21221, 0.21321, 0.21421, 0.21522,
0.21622, 0.21722, 0.21822, 0.21922, 0.22022, 0.22122, 0.22222, 0.22322, 0.22422, 0.22523, 0.22623, 0.22723, 0.22823, 0.22923, 0.23023, 0.23123, 0.23223, 0.23323, 0.23423, 0.23524, 0.23624, 0.23724, 0.23824, 0.23924,
0.24024, 0.24124, 0.24224, 0.24324, 0.24424, 0.24525, 0.24625, 0.24725, 0.24825, 0.24925, 0.25025, 0.25125, 0.25225, 0.25325, 0.25425, 0.25526, 0.25626, 0.25726, 0.25826, 0.25926, 0.26026, 0.26126, 0.26226, 0.26326,
0.26426, 0.26527, 0.26627, 0.26727, 0.26827, 0.26927, 0.27027, 0.27127, 0.27227, 0.27327, 0.27427, 0.27528, 0.27628, 0.27728, 0.27828, 0.27928, 0.28028, 0.28128, 0.28228, 0.28328, 0.28428, 0.28529, 0.28629, 0.28729,
0.28829, 0.28929, 0.29029, 0.29129, 0.29229, 0.29329, 0.29429, 0.2953, 0.2963, 0.2973, 0.2983, 0.2993, 0.3003, 0.3013, 0.3023, 0.3033, 0.3043, 0.30531, 0.30631, 0.30731, 0.30831, 0.30931, 0.31031, 0.31131,
0.31231, 0.31331, 0.31431, 0.31532, 0.31632, 0.31732, 0.31832, 0.31932, 0.32032, 0.32132, 0.32232, 0.32332, 0.32432, 0.32533, 0.32633, 0.32733, 0.32833, 0.32933, 0.33033, 0.33133, 0.33233, 0.33333, 0.33433, 0.33534,
0.33634, 0.33734, 0.33834, 0.33934, 0.34034, 0.34134, 0.34234, 0.34334, 0.34434, 0.34535, 0.34635, 0.34735, 0.34835, 0.34935, 0.35035, 0.35135, 0.35235, 0.35335, 0.35435, 0.35536, 0.35636, 0.35736, 0.35836, 0.35936,
0.36036, 0.36136, 0.36236, 0.36336, 0.36436, 0.36537, 0.36637, 0.36737, 0.36837, 0.36937, 0.37037, 0.37137, 0.37237, 0.37337, 0.37437, 0.37538, 0.37638, 0.37738, 0.37838, 0.37938, 0.38038, 0.38138, 0.38238, 0.38338,
0.38438, 0.38539, 0.38639, 0.38739, 0.38839, 0.38939, 0.39039, 0.39139, 0.39239, 0.39339, 0.39439, 0.3954, 0.3964, 0.3974, 0.3984, 0.3994, 0.4004, 0.4014, 0.4024, 0.4034, 0.4044, 0.40541, 0.40641, 0.40741,
0.40841, 0.40941, 0.41041, 0.41141, 0.41241, 0.41341, 0.41441, 0.41542, 0.41642, 0.41742, 0.41842, 0.41942, 0.42042, 0.42142, 0.42242, 0.42342, 0.42442, 0.42543, 0.42643, 0.42743, 0.42843, 0.42943, 0.43043, 0.43143,
0.43243, 0.43343, 0.43443, 0.43544, 0.43644, 0.43744, 0.43844, 0.43944, 0.44044, 0.44144, 0.44244, 0.44344, 0.44444, 0.44545, 0.44645, 0.44745, 0.44845, 0.44945, 0.45045, 0.45145, 0.45245, 0.45345, 0.45445, 0.45546,
0.45646, 0.45746, 0.45846, 0.45946, 0.46046, 0.46146, 0.46246, 0.46346, 0.46446, 0.46547, 0.46647, 0.46747, 0.46847, 0.46947, 0.47047, 0.47147, 0.47247, 0.47347, 0.47447, 0.47548, 0.47648, 0.47748, 0.47848, 0.47948,
0.48048, 0.48148, 0.48248, 0.48348, 0.48448, 0.48549, 0.48649, 0.48749, 0.48849, 0.48949, 0.49049, 0.49149, 0.49249, 0.49349, 0.49449, 0.4955, 0.4965, 0.4975, 0.4985, 0.4995, 0.5005, 0.5015, 0.5025, 0.5035,
0.5045, 0.50551, 0.50651, 0.50751, 0.50851, 0.50951, 0.51051, 0.51151, 0.51251, 0.51351, 0.51451, 0.51552, 0.51652, 0.51752, 0.51852, 0.51952, 0.52052, 0.52152, 0.52252, 0.52352, 0.52452, 0.52553, 0.52653, 0.52753,
0.52853, 0.52953, 0.53053, 0.53153, 0.53253, 0.53353, 0.53453, 0.53554, 0.53654, 0.53754, 0.53854, 0.53954, 0.54054, 0.54154, 0.54254, 0.54354, 0.54454, 0.54555, 0.54655, 0.54755, 0.54855, 0.54955, 0.55055, 0.55155,
0.55255, 0.55355, 0.55455, 0.55556, 0.55656, 0.55756, 0.55856, 0.55956, 0.56056, 0.56156, 0.56256, 0.56356, 0.56456, 0.56557, 0.56657, 0.56757, 0.56857, 0.56957, 0.57057, 0.57157, 0.57257, 0.57357, 0.57457, 0.57558,
0.57658, 0.57758, 0.57858, 0.57958, 0.58058, 0.58158, 0.58258, 0.58358, 0.58458, 0.58559, 0.58659, 0.58759, 0.58859, 0.58959, 0.59059, 0.59159, 0.59259, 0.59359, 0.59459, 0.5956, 0.5966, 0.5976, 0.5986, 0.5996,
0.6006, 0.6016, 0.6026, 0.6036, 0.6046, 0.60561, 0.60661, 0.60761, 0.60861, 0.60961, 0.61061, 0.61161, 0.61261, 0.61361, 0.61461, 0.61562, 0.61662, 0.61762, 0.61862, 0.61962, 0.62062, 0.62162, 0.62262, 0.62362,
0.62462, 0.62563, 0.62663, 0.62763, 0.62863, 0.62963, 0.63063, 0.63163, 0.63263, 0.63363, 0.63463, 0.63564, 0.63664, 0.63764, 0.63864, 0.63964, 0.64064, 0.64164, 0.64264, 0.64364, 0.64464, 0.64565, 0.64665, 0.64765,
0.64865, 0.64965, 0.65065, 0.65165, 0.65265, 0.65365, 0.65465, 0.65566, 0.65666, 0.65766, 0.65866, 0.65966, 0.66066, 0.66166, 0.66266, 0.66366, 0.66466, 0.66567, 0.66667, 0.66767, 0.66867, 0.66967, 0.67067, 0.67167,
0.67267, 0.67367, 0.67467, 0.67568, 0.67668, 0.67768, 0.67868, 0.67968, 0.68068, 0.68168, 0.68268, 0.68368, 0.68468, 0.68569, 0.68669, 0.68769, 0.68869, 0.68969, 0.69069, 0.69169, 0.69269, 0.69369, 0.69469, 0.6957,
0.6967, 0.6977, 0.6987, 0.6997, 0.7007, 0.7017, 0.7027, 0.7037, 0.7047, 0.70571, 0.70671, 0.70771, 0.70871, 0.70971, 0.71071, 0.71171, 0.71271, 0.71371, 0.71471, 0.71572, 0.71672, 0.71772, 0.71872, 0.71972,
0.72072, 0.72172, 0.72272, 0.72372, 0.72472, 0.72573, 0.72673, 0.72773, 0.72873, 0.72973, 0.73073, 0.73173, 0.73273, 0.73373, 0.73473, 0.73574, 0.73674, 0.73774, 0.73874, 0.73974, 0.74074, 0.74174, 0.74274, 0.74374,
0.74474, 0.74575, 0.74675, 0.74775, 0.74875, 0.74975, 0.75075, 0.75175, 0.75275, 0.75375, 0.75475, 0.75576, 0.75676, 0.75776, 0.75876, 0.75976, 0.76076, 0.76176, 0.76276, 0.76376, 0.76476, 0.76577, 0.76677, 0.76777,
0.76877, 0.76977, 0.77077, 0.77177, 0.77277, 0.77377, 0.77477, 0.77578, 0.77678, 0.77778, 0.77878, 0.77978, 0.78078, 0.78178, 0.78278, 0.78378, 0.78478, 0.78579, 0.78679, 0.78779, 0.78879, 0.78979, 0.79079, 0.79179,
0.79279, 0.79379, 0.79479, 0.7958, 0.7968, 0.7978, 0.7988, 0.7998, 0.8008, 0.8018, 0.8028, 0.8038, 0.8048, 0.80581, 0.80681, 0.80781, 0.80881, 0.80981, 0.81081, 0.81181, 0.81281, 0.81381, 0.81481, 0.81582,
0.81682, 0.81782, 0.81882, 0.81982, 0.82082, 0.82182, 0.82282, 0.82382, 0.82482, 0.82583, 0.82683, 0.82783, 0.82883, 0.82983, 0.83083, 0.83183, 0.83283, 0.83383, 0.83483, 0.83584, 0.83684, 0.83784, 0.83884, 0.83984,
0.84084, 0.84184, 0.84284, 0.84384, 0.84484, 0.84585, 0.84685, 0.84785, 0.84885, 0.84985, 0.85085, 0.85185, 0.85285, 0.85385, 0.85485, 0.85586, 0.85686, 0.85786, 0.85886, 0.85986, 0.86086, 0.86186, 0.86286, 0.86386,
0.86486, 0.86587, 0.86687, 0.86787, 0.86887, 0.86987, 0.87087, 0.87187, 0.87287, 0.87387, 0.87487, 0.87588, 0.87688, 0.87788, 0.87888, 0.87988, 0.88088, 0.88188, 0.88288, 0.88388, 0.88488, 0.88589, 0.88689, 0.88789,
0.88889, 0.88989, 0.89089, 0.89189, 0.89289, 0.89389, 0.89489, 0.8959, 0.8969, 0.8979, 0.8989, 0.8999, 0.9009, 0.9019, 0.9029, 0.9039, 0.9049, 0.90591, 0.90691, 0.90791, 0.90891, 0.90991, 0.91091, 0.91191,
0.91291, 0.91391, 0.91491, 0.91592, 0.91692, 0.91792, 0.91892, 0.91992, 0.92092, 0.92192, 0.92292, 0.92392, 0.92492, 0.92593, 0.92693, 0.92793, 0.92893, 0.92993, 0.93093, 0.93193, 0.93293, 0.93393, 0.93493, 0.93594,
0.93694, 0.93794, 0.93894, 0.93994, 0.94094, 0.94194, 0.94294, 0.94394, 0.94494, 0.94595, 0.94695, 0.94795, 0.94895, 0.94995, 0.95095, 0.95195, 0.95295, 0.95395, 0.95495, 0.95596, 0.95696, 0.95796, 0.95896, 0.95996,
0.96096, 0.96196, 0.96296, 0.96396, 0.96496, 0.96597, 0.96697, 0.96797, 0.96897, 0.96997, 0.97097, 0.97197, 0.97297, 0.97397, 0.97497, 0.97598, 0.97698, 0.97798, 0.97898, 0.97998, 0.98098, 0.98198, 0.98298, 0.98398,
0.98498, 0.98599, 0.98699, 0.98799, 0.98899, 0.98999, 0.99099, 0.99199, 0.99299, 0.99399, 0.99499, 0.996, 0.997, 0.998, 0.999, 1]), array([[ 0.14472, 0.14472, 0.21997, ..., 1, 1, 1],
[ 0.11592, 0.11592, 0.17713, ..., 1, 1, 1],
[ 0.27451, 0.27451, 0.38846, ..., 1, 1, 1],
[ 0.13429, 0.13429, 0.1833, ..., 1, 1, 1]]), 'Confidence', 'Precision'], [array([ 0, 0.001001, 0.002002, 0.003003, 0.004004, 0.005005, 0.006006, 0.007007, 0.008008, 0.009009, 0.01001, 0.011011, 0.012012, 0.013013, 0.014014, 0.015015, 0.016016, 0.017017, 0.018018, 0.019019, 0.02002, 0.021021, 0.022022, 0.023023,
0.024024, 0.025025, 0.026026, 0.027027, 0.028028, 0.029029, 0.03003, 0.031031, 0.032032, 0.033033, 0.034034, 0.035035, 0.036036, 0.037037, 0.038038, 0.039039, 0.04004, 0.041041, 0.042042, 0.043043, 0.044044, 0.045045, 0.046046, 0.047047,
0.048048, 0.049049, 0.05005, 0.051051, 0.052052, 0.053053, 0.054054, 0.055055, 0.056056, 0.057057, 0.058058, 0.059059, 0.06006, 0.061061, 0.062062, 0.063063, 0.064064, 0.065065, 0.066066, 0.067067, 0.068068, 0.069069, 0.07007, 0.071071,
0.072072, 0.073073, 0.074074, 0.075075, 0.076076, 0.077077, 0.078078, 0.079079, 0.08008, 0.081081, 0.082082, 0.083083, 0.084084, 0.085085, 0.086086, 0.087087, 0.088088, 0.089089, 0.09009, 0.091091, 0.092092, 0.093093, 0.094094, 0.095095,
0.096096, 0.097097, 0.098098, 0.099099, 0.1001, 0.1011, 0.1021, 0.1031, 0.1041, 0.10511, 0.10611, 0.10711, 0.10811, 0.10911, 0.11011, 0.11111, 0.11211, 0.11311, 0.11411, 0.11512, 0.11612, 0.11712, 0.11812, 0.11912,
0.12012, 0.12112, 0.12212, 0.12312, 0.12412, 0.12513, 0.12613, 0.12713, 0.12813, 0.12913, 0.13013, 0.13113, 0.13213, 0.13313, 0.13413, 0.13514, 0.13614, 0.13714, 0.13814, 0.13914, 0.14014, 0.14114, 0.14214, 0.14314,
0.14414, 0.14515, 0.14615, 0.14715, 0.14815, 0.14915, 0.15015, 0.15115, 0.15215, 0.15315, 0.15415, 0.15516, 0.15616, 0.15716, 0.15816, 0.15916, 0.16016, 0.16116, 0.16216, 0.16316, 0.16416, 0.16517, 0.16617, 0.16717,
0.16817, 0.16917, 0.17017, 0.17117, 0.17217, 0.17317, 0.17417, 0.17518, 0.17618, 0.17718, 0.17818, 0.17918, 0.18018, 0.18118, 0.18218, 0.18318, 0.18418, 0.18519, 0.18619, 0.18719, 0.18819, 0.18919, 0.19019, 0.19119,
0.19219, 0.19319, 0.19419, 0.1952, 0.1962, 0.1972, 0.1982, 0.1992, 0.2002, 0.2012, 0.2022, 0.2032, 0.2042, 0.20521, 0.20621, 0.20721, 0.20821, 0.20921, 0.21021, 0.21121, 0.21221, 0.21321, 0.21421, 0.21522,
0.21622, 0.21722, 0.21822, 0.21922, 0.22022, 0.22122, 0.22222, 0.22322, 0.22422, 0.22523, 0.22623, 0.22723, 0.22823, 0.22923, 0.23023, 0.23123, 0.23223, 0.23323, 0.23423, 0.23524, 0.23624, 0.23724, 0.23824, 0.23924,
0.24024, 0.24124, 0.24224, 0.24324, 0.24424, 0.24525, 0.24625, 0.24725, 0.24825, 0.24925, 0.25025, 0.25125, 0.25225, 0.25325, 0.25425, 0.25526, 0.25626, 0.25726, 0.25826, 0.25926, 0.26026, 0.26126, 0.26226, 0.26326,
0.26426, 0.26527, 0.26627, 0.26727, 0.26827, 0.26927, 0.27027, 0.27127, 0.27227, 0.27327, 0.27427, 0.27528, 0.27628, 0.27728, 0.27828, 0.27928, 0.28028, 0.28128, 0.28228, 0.28328, 0.28428, 0.28529, 0.28629, 0.28729,
0.28829, 0.28929, 0.29029, 0.29129, 0.29229, 0.29329, 0.29429, 0.2953, 0.2963, 0.2973, 0.2983, 0.2993, 0.3003, 0.3013, 0.3023, 0.3033, 0.3043, 0.30531, 0.30631, 0.30731, 0.30831, 0.30931, 0.31031, 0.31131,
0.31231, 0.31331, 0.31431, 0.31532, 0.31632, 0.31732, 0.31832, 0.31932, 0.32032, 0.32132, 0.32232, 0.32332, 0.32432, 0.32533, 0.32633, 0.32733, 0.32833, 0.32933, 0.33033, 0.33133, 0.33233, 0.33333, 0.33433, 0.33534,
0.33634, 0.33734, 0.33834, 0.33934, 0.34034, 0.34134, 0.34234, 0.34334, 0.34434, 0.34535, 0.34635, 0.34735, 0.34835, 0.34935, 0.35035, 0.35135, 0.35235, 0.35335, 0.35435, 0.35536, 0.35636, 0.35736, 0.35836, 0.35936,
0.36036, 0.36136, 0.36236, 0.36336, 0.36436, 0.36537, 0.36637, 0.36737, 0.36837, 0.36937, 0.37037, 0.37137, 0.37237, 0.37337, 0.37437, 0.37538, 0.37638, 0.37738, 0.37838, 0.37938, 0.38038, 0.38138, 0.38238, 0.38338,
0.38438, 0.38539, 0.38639, 0.38739, 0.38839, 0.38939, 0.39039, 0.39139, 0.39239, 0.39339, 0.39439, 0.3954, 0.3964, 0.3974, 0.3984, 0.3994, 0.4004, 0.4014, 0.4024, 0.4034, 0.4044, 0.40541, 0.40641, 0.40741,
0.40841, 0.40941, 0.41041, 0.41141, 0.41241, 0.41341, 0.41441, 0.41542, 0.41642, 0.41742, 0.41842, 0.41942, 0.42042, 0.42142, 0.42242, 0.42342, 0.42442, 0.42543, 0.42643, 0.42743, 0.42843, 0.42943, 0.43043, 0.43143,
0.43243, 0.43343, 0.43443, 0.43544, 0.43644, 0.43744, 0.43844, 0.43944, 0.44044, 0.44144, 0.44244, 0.44344, 0.44444, 0.44545, 0.44645, 0.44745, 0.44845, 0.44945, 0.45045, 0.45145, 0.45245, 0.45345, 0.45445, 0.45546,
0.45646, 0.45746, 0.45846, 0.45946, 0.46046, 0.46146, 0.46246, 0.46346, 0.46446, 0.46547, 0.46647, 0.46747, 0.46847, 0.46947, 0.47047, 0.47147, 0.47247, 0.47347, 0.47447, 0.47548, 0.47648, 0.47748, 0.47848, 0.47948,
0.48048, 0.48148, 0.48248, 0.48348, 0.48448, 0.48549, 0.48649, 0.48749, 0.48849, 0.48949, 0.49049, 0.49149, 0.49249, 0.49349, 0.49449, 0.4955, 0.4965, 0.4975, 0.4985, 0.4995, 0.5005, 0.5015, 0.5025, 0.5035,
0.5045, 0.50551, 0.50651, 0.50751, 0.50851, 0.50951, 0.51051, 0.51151, 0.51251, 0.51351, 0.51451, 0.51552, 0.51652, 0.51752, 0.51852, 0.51952, 0.52052, 0.52152, 0.52252, 0.52352, 0.52452, 0.52553, 0.52653, 0.52753,
0.52853, 0.52953, 0.53053, 0.53153, 0.53253, 0.53353, 0.53453, 0.53554, 0.53654, 0.53754, 0.53854, 0.53954, 0.54054, 0.54154, 0.54254, 0.54354, 0.54454, 0.54555, 0.54655, 0.54755, 0.54855, 0.54955, 0.55055, 0.55155,
0.55255, 0.55355, 0.55455, 0.55556, 0.55656, 0.55756, 0.55856, 0.55956, 0.56056, 0.56156, 0.56256, 0.56356, 0.56456, 0.56557, 0.56657, 0.56757, 0.56857, 0.56957, 0.57057, 0.57157, 0.57257, 0.57357, 0.57457, 0.57558,
0.57658, 0.57758, 0.57858, 0.57958, 0.58058, 0.58158, 0.58258, 0.58358, 0.58458, 0.58559, 0.58659, 0.58759, 0.58859, 0.58959, 0.59059, 0.59159, 0.59259, 0.59359, 0.59459, 0.5956, 0.5966, 0.5976, 0.5986, 0.5996,
0.6006, 0.6016, 0.6026, 0.6036, 0.6046, 0.60561, 0.60661, 0.60761, 0.60861, 0.60961, 0.61061, 0.61161, 0.61261, 0.61361, 0.61461, 0.61562, 0.61662, 0.61762, 0.61862, 0.61962, 0.62062, 0.62162, 0.62262, 0.62362,
0.62462, 0.62563, 0.62663, 0.62763, 0.62863, 0.62963, 0.63063, 0.63163, 0.63263, 0.63363, 0.63463, 0.63564, 0.63664, 0.63764, 0.63864, 0.63964, 0.64064, 0.64164, 0.64264, 0.64364, 0.64464, 0.64565, 0.64665, 0.64765,
0.64865, 0.64965, 0.65065, 0.65165, 0.65265, 0.65365, 0.65465, 0.65566, 0.65666, 0.65766, 0.65866, 0.65966, 0.66066, 0.66166, 0.66266, 0.66366, 0.66466, 0.66567, 0.66667, 0.66767, 0.66867, 0.66967, 0.67067, 0.67167,
0.67267, 0.67367, 0.67467, 0.67568, 0.67668, 0.67768, 0.67868, 0.67968, 0.68068, 0.68168, 0.68268, 0.68368, 0.68468, 0.68569, 0.68669, 0.68769, 0.68869, 0.68969, 0.69069, 0.69169, 0.69269, 0.69369, 0.69469, 0.6957,
0.6967, 0.6977, 0.6987, 0.6997, 0.7007, 0.7017, 0.7027, 0.7037, 0.7047, 0.70571, 0.70671, 0.70771, 0.70871, 0.70971, 0.71071, 0.71171, 0.71271, 0.71371, 0.71471, 0.71572, 0.71672, 0.71772, 0.71872, 0.71972,
0.72072, 0.72172, 0.72272, 0.72372, 0.72472, 0.72573, 0.72673, 0.72773, 0.72873, 0.72973, 0.73073, 0.73173, 0.73273, 0.73373, 0.73473, 0.73574, 0.73674, 0.73774, 0.73874, 0.73974, 0.74074, 0.74174, 0.74274, 0.74374,
0.74474, 0.74575, 0.74675, 0.74775, 0.74875, 0.74975, 0.75075, 0.75175, 0.75275, 0.75375, 0.75475, 0.75576, 0.75676, 0.75776, 0.75876, 0.75976, 0.76076, 0.76176, 0.76276, 0.76376, 0.76476, 0.76577, 0.76677, 0.76777,
0.76877, 0.76977, 0.77077, 0.77177, 0.77277, 0.77377, 0.77477, 0.77578, 0.77678, 0.77778, 0.77878, 0.77978, 0.78078, 0.78178, 0.78278, 0.78378, 0.78478, 0.78579, 0.78679, 0.78779, 0.78879, 0.78979, 0.79079, 0.79179,
0.79279, 0.79379, 0.79479, 0.7958, 0.7968, 0.7978, 0.7988, 0.7998, 0.8008, 0.8018, 0.8028, 0.8038, 0.8048, 0.80581, 0.80681, 0.80781, 0.80881, 0.80981, 0.81081, 0.81181, 0.81281, 0.81381, 0.81481, 0.81582,
0.81682, 0.81782, 0.81882, 0.81982, 0.82082, 0.82182, 0.82282, 0.82382, 0.82482, 0.82583, 0.82683, 0.82783, 0.82883, 0.82983, 0.83083, 0.83183, 0.83283, 0.83383, 0.83483, 0.83584, 0.83684, 0.83784, 0.83884, 0.83984,
0.84084, 0.84184, 0.84284, 0.84384, 0.84484, 0.84585, 0.84685, 0.84785, 0.84885, 0.84985, 0.85085, 0.85185, 0.85285, 0.85385, 0.85485, 0.85586, 0.85686, 0.85786, 0.85886, 0.85986, 0.86086, 0.86186, 0.86286, 0.86386,
0.86486, 0.86587, 0.86687, 0.86787, 0.86887, 0.86987, 0.87087, 0.87187, 0.87287, 0.87387, 0.87487, 0.87588, 0.87688, 0.87788, 0.87888, 0.87988, 0.88088, 0.88188, 0.88288, 0.88388, 0.88488, 0.88589, 0.88689, 0.88789,
0.88889, 0.88989, 0.89089, 0.89189, 0.89289, 0.89389, 0.89489, 0.8959, 0.8969, 0.8979, 0.8989, 0.8999, 0.9009, 0.9019, 0.9029, 0.9039, 0.9049, 0.90591, 0.90691, 0.90791, 0.90891, 0.90991, 0.91091, 0.91191,
0.91291, 0.91391, 0.91491, 0.91592, 0.91692, 0.91792, 0.91892, 0.91992, 0.92092, 0.92192, 0.92292, 0.92392, 0.92492, 0.92593, 0.92693, 0.92793, 0.92893, 0.92993, 0.93093, 0.93193, 0.93293, 0.93393, 0.93493, 0.93594,
0.93694, 0.93794, 0.93894, 0.93994, 0.94094, 0.94194, 0.94294, 0.94394, 0.94494, 0.94595, 0.94695, 0.94795, 0.94895, 0.94995, 0.95095, 0.95195, 0.95295, 0.95395, 0.95495, 0.95596, 0.95696, 0.95796, 0.95896, 0.95996,
0.96096, 0.96196, 0.96296, 0.96396, 0.96496, 0.96597, 0.96697, 0.96797, 0.96897, 0.96997, 0.97097, 0.97197, 0.97297, 0.97397, 0.97497, 0.97598, 0.97698, 0.97798, 0.97898, 0.97998, 0.98098, 0.98198, 0.98298, 0.98398,
0.98498, 0.98599, 0.98699, 0.98799, 0.98899, 0.98999, 0.99099, 0.99199, 0.99299, 0.99399, 0.99499, 0.996, 0.997, 0.998, 0.999, 1]), array([[ 1, 1, 0.98876, ..., 0, 0, 0],
[ 1, 1, 0.98901, ..., 0, 0, 0],
[ 0.98824, 0.98824, 0.98824, ..., 0, 0, 0],
[ 0.98246, 0.98246, 0.96491, ..., 0, 0, 0]]), 'Confidence', 'Recall']]
fitness: 0.8210046188823464
keys: ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
maps: array([ 0.84763, 0.7607, 0.86657, 0.75124])
names: {0: 'buffalo', 1: 'elephant', 2: 'rhino', 3: 'zebra'}
plot: True
results_dict: {'metrics/precision(B)': 0.9382732716357025, 'metrics/recall(B)': 0.8850679729846709, 'metrics/mAP50(B)': 0.9512262768368661, 'metrics/mAP50-95(B)': 0.8065355457762887, 'fitness': 0.8210046188823464}
save_dir: PosixPath('runs/detect/train')
speed: {'preprocess': 0.06935741778256165, 'inference': 1.2025215110689815, 'loss': 0.00013822236926191382, 'postprocess': 0.6716766088114431}
task: 'detect'
In [6]:
better_model = YOLO("runs/detect/train/weights/best.pt")
import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
from ultralytics import YOLO
# Function to draw predicted bounding boxes
def draw_predictions(image_path):
"""Runs YOLOv11 inference and draws bounding boxes on the image."""
# Load image
image = cv2.imread(image_path)
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Run YOLOv11 model on the image
results = better_model(image_rgb)
# Draw bounding boxes
for result in results:
boxes = result.boxes.xyxy # Bounding boxes (x1, y1, x2, y2)
scores = result.boxes.conf # Confidence scores
labels = result.boxes.cls # Class labels
for i, box in enumerate(boxes):
x1, y1, x2, y2 = map(int, box) # Convert to integers
label = model.names[int(labels[i])]
score = scores[i]
# Draw bounding box
cv2.rectangle(image_rgb, (x1, y1), (x2, y2), (255, 0, 0), 2)
cv2.putText(image_rgb, f"{label} {score:.2f}", (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
return image_rgb # Return processed image
# Plot 6 images with YOLOv11 predictions
fig, axes = plt.subplots(3, 3, figsize=(15, 15))
for ax, image_file in zip(axes.flatten(), selected_images):
image_path = os.path.join(image_folder, image_file)
# Process image with YOLOv11
predicted_image = draw_predictions(image_path)
# Display image
ax.imshow(predicted_image)
ax.set_title(f"Predictions for {image_file}")
ax.axis("off")
plt.tight_layout()
plt.show()
0: 448x640 2 elephants, 8.4ms Speed: 0.6ms preprocess, 8.4ms inference, 0.7ms postprocess per image at shape (1, 3, 448, 640) 0: 416x640 1 zebra, 8.7ms Speed: 0.8ms preprocess, 8.7ms inference, 0.6ms postprocess per image at shape (1, 3, 416, 640) 0: 480x640 1 rhino, 8.5ms Speed: 1.0ms preprocess, 8.5ms inference, 0.6ms postprocess per image at shape (1, 3, 480, 640) 0: 512x640 2 zebras, 8.5ms Speed: 0.9ms preprocess, 8.5ms inference, 0.6ms postprocess per image at shape (1, 3, 512, 640) 0: 448x640 1 rhino, 7.5ms Speed: 0.8ms preprocess, 7.5ms inference, 0.6ms postprocess per image at shape (1, 3, 448, 640) 0: 448x640 1 buffalo, 7.1ms Speed: 0.9ms preprocess, 7.1ms inference, 0.6ms postprocess per image at shape (1, 3, 448, 640) 0: 448x640 1 zebra, 7.1ms Speed: 0.6ms preprocess, 7.1ms inference, 0.6ms postprocess per image at shape (1, 3, 448, 640) 0: 640x640 6 rhinos, 8.8ms Speed: 1.0ms preprocess, 8.8ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 640) 0: 640x448 1 elephant, 8.9ms Speed: 0.6ms preprocess, 8.9ms inference, 0.7ms postprocess per image at shape (1, 3, 640, 448)
Semantic Segmentation with a Custom U-Net¶
In [37]:
import pandas as pd
import numpy as np
import os
#Organize oxford pets into folder structures
np.random.seed(42)
# Get list of image files
image_folder = "data/oxford-iiit-pet/images"
image_files = [f for f in os.listdir(image_folder) if f.endswith(('.jpg', '.png'))]
# Create dataframe
df = pd.DataFrame(image_files, columns=['filename'])
# Add label column
df['label'] = df['filename'].apply(lambda x: 'cat' if x[0].isupper() else 'dog')
# Add set column
df['set'] = np.random.choice([1, 0], size=len(df), p=[0.5, 0.5])
print(df.head())
filename label set 0 german_shorthaired_184.jpg dog 1 1 Birman_120.jpg cat 0 2 great_pyrenees_20.jpg dog 0 3 samoyed_6.jpg dog 0 4 american_pit_bull_terrier_28.jpg dog 1
In [39]:
import os
# Define the base path
base_path = "data/oxford_pets"
# Define the folder structure
folders = [
"train/cats", "train/dogs",
"valid/cats", "valid/dogs",
"train_trimaps/cats", "train_trimaps/dogs",
"valid_trimaps/cats", "valid_trimaps/dogs"
]
# Create the folders
for folder in folders:
os.makedirs(os.path.join(base_path, folder), exist_ok=True)
print("Folder structure created successfully.")
Folder structure created successfully.
In [40]:
import shutil
# Define the trimap folder
trimap_folder = "data/oxford-iiit-pet/annotations/trimaps"
# Function to copy files to the appropriate folder
def copy_files(row):
image_file = row['filename']
label = row['label']
set_type = 'train' if row['set'] == 1 else 'valid'
# Define source and destination paths
image_src = os.path.join(image_folder, image_file)
trimap_src = os.path.join(trimap_folder, image_file.replace('.jpg', '.png'))
if os.path.exists(trimap_src):
image_dst = os.path.join(base_path, f"{set_type}/{label}s", image_file)
trimap_dst = os.path.join(base_path, f"{set_type}_trimaps/{label}s", image_file.replace('.jpg', '.png'))
# Copy image and trimap
shutil.copy(image_src, image_dst)
shutil.copy(trimap_src, trimap_dst)
# Apply the function to each row in the dataframe
df.apply(copy_files, axis=1)
print("Files copied successfully.")
Files copied successfully.
In [41]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
# Define class mapping
CLASS_MAP = {"cats": 0, "dogs": 1} # 0 for cats, 1 for dogs
# Define a consistent image size
IMAGE_SIZE = (256, 256)
class OxfordPetsDataset(Dataset):
def __init__(self, image_dir, trimap_dir, transform=None):
self.image_dir = image_dir
self.trimap_dir = trimap_dir
self.transform = transform
# List all image files
self.image_files = sorted(os.listdir(image_dir))
# Define transformations to ensure consistent size for trimaps (without normalization)
self.trimap_transform = transforms.Compose([
transforms.Resize(IMAGE_SIZE, interpolation=Image.NEAREST), # Resize while preserving labels
])
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
# Get filename
image_filename = self.image_files[idx]
image_path = os.path.join(self.image_dir, image_filename)
trimap_path = os.path.join(self.trimap_dir, image_filename.replace(".jpg", ".png"))
# Load image and trimap
image = Image.open(image_path).convert("RGB")
trimap = Image.open(trimap_path).convert("L") # Trimap is grayscale
# Resize trimap to (256, 256)
trimap = self.trimap_transform(trimap)
trimap = np.array(trimap, dtype=np.uint8) # Ensure integer encoding
# Get class label from folder name (cats or dogs)
class_name = os.path.basename(os.path.dirname(image_path))
class_label = CLASS_MAP[class_name] # 0 for cats, 1 for dogs
# Modify trimap coding (3-class)
new_trimap = np.zeros_like(trimap, dtype=np.uint8) # Initialize new trimap
new_trimap[trimap == 1] = 1 # Outline (both cats & dogs)
new_trimap[trimap == 2] = 2 # Object (both cats & dogs)
# Convert trimap to LongTensor before returning
new_trimap = torch.tensor(new_trimap, dtype=torch.long) # Ensure it's stored correctly
# Apply transformations to image (but NOT trimap)
if self.transform:
image = self.transform(image)
return image, new_trimap, torch.tensor(class_label, dtype=torch.long)
In [42]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
# Define dataset paths for images and trimaps
train_cats_images = "data/oxford_pets/train/cats"
train_dogs_images = "data/oxford_pets/train/dogs"
train_cats_trimaps = "data/oxford_pets/train_trimaps/cats"
train_dogs_trimaps = "data/oxford_pets/train_trimaps/dogs"
valid_cats_images = "data/oxford_pets/valid/cats"
valid_dogs_images = "data/oxford_pets/valid/dogs"
valid_cats_trimaps = "data/oxford_pets/valid_trimaps/cats"
valid_dogs_trimaps = "data/oxford_pets/valid_trimaps/dogs"
In [43]:
# Define transformations (apply only to images)
data_transforms = transforms.Compose([
transforms.Resize((256, 256)), # Ensure consistent image size
transforms.ToTensor(), # Convert image to PyTorch tensor (normalizes to [0,1])
])
# Create Dataset Instances
train_cats_dataset = OxfordPetsDataset(train_cats_images, train_cats_trimaps, transform=data_transforms)
train_dogs_dataset = OxfordPetsDataset(train_dogs_images, train_dogs_trimaps, transform=data_transforms)
valid_cats_dataset = OxfordPetsDataset(valid_cats_images, valid_cats_trimaps, transform=data_transforms)
valid_dogs_dataset = OxfordPetsDataset(valid_dogs_images, valid_dogs_trimaps, transform=data_transforms)
# Merge datasets
from torch.utils.data import ConcatDataset
train_dataset = ConcatDataset([train_cats_dataset, train_dogs_dataset])
valid_dataset = ConcatDataset([valid_cats_dataset, valid_dogs_dataset])
# Create DataLoaders
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
print(f"Train set: {len(train_loader.dataset)} samples")
print(f"Valid set: {len(valid_loader.dataset)} samples")
Train set: 3723 samples Valid set: 3667 samples
In [44]:
import matplotlib.pyplot as plt
import random
import numpy as np
# Select a random index from the validation dataset
random_idx = random.randint(0, len(valid_dataset) - 1)
# Get the image, trimap, and class label
image, trimap, class_label = valid_dataset[random_idx]
# Convert image tensor to NumPy format for visualization
img_np = image.permute(1, 2, 0).numpy() # Convert from (C, H, W) to (H, W, C)
# Convert trimap tensor to NumPy
trimap_np = trimap.cpu().numpy()
np.unique(trimap_np)
Out[44]:
array([0, 1, 2])
In [45]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors
# Define class-to-color mapping for trimap
trimap_colors = {
0: (0, 0, 0), # Background - Black
1: (0, 0, 1), # Dog Outline - Blue
2: (0, 1, 0), # Dog Object - Green
3: (1, 0, 0), # Cat Outline - Red
4: (1, 1, 0), # Cat Object - Yellow
}
# Create a ListedColormap for visualization
cmap = mcolors.ListedColormap([trimap_colors[i] for i in range(len(trimap_colors))])
# Get a batch of training data
images, trimaps, labels = next(iter(train_loader))
# Show the first 4 images and trimaps
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i in range(4):
# Convert tensor image to NumPy format
img_np = images[i].permute(1, 2, 0).cpu().numpy()
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min()) # Normalize for display
# Convert trimap tensor to NumPy
trimap_np = trimaps[i].cpu().numpy()
# Image display
axes[0, i].imshow(img_np)
axes[0, i].set_title(f"Class: {'Dog' if labels[i] == 1 else 'Cat'}")
axes[0, i].axis("off")
# Trimap display with custom colors
axes[1, i].imshow(trimap_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
axes[1, i].set_title("Trimap (Labeled)")
axes[1, i].axis("off")
# Create legend for the trimap colors
fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])
cbar = plt.colorbar(mappable=plt.cm.ScalarMappable(cmap=cmap), cax=cbar_ax, ticks=np.arange(len(trimap_colors)))
cbar.ax.set_yticklabels(["Background", "Dog Outline", "Dog", "Cat Outline", "Cat"]) # Custom labels
plt.tight_layout()
plt.show()
/tmp/ipykernel_380213/3389472999.py:47: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect. plt.tight_layout()
In [46]:
import torch
import torch.nn as nn
import torchvision.models as models
class UNetMobileNet(nn.Module):
def __init__(self, num_classes=5, pretrained=True, freeze_encoder=True, use_skip_connections=True):
super(UNetMobileNet, self).__init__()
self.use_skip_connections = use_skip_connections # Toggle skip connections
# Load MobileNetV2 as encoder
mobilenet = models.mobilenet_v2(pretrained=pretrained)
encoder_layers = list(mobilenet.features.children())
# Encoder blocks
self.encoder1 = nn.Sequential(*encoder_layers[:2]) # 128 × 128
self.encoder2 = nn.Sequential(*encoder_layers[2:4]) # 64 × 64
self.encoder3 = nn.Sequential(*encoder_layers[4:7]) # 32 × 32
self.encoder4 = nn.Sequential(*encoder_layers[7:14]) # 16 × 16
self.encoder5 = nn.Sequential(*encoder_layers[14:]) # 8 × 8
# Freeze encoder if specified
if freeze_encoder:
for param in self.encoder1.parameters():
param.requires_grad = False
for param in self.encoder2.parameters():
param.requires_grad = False
for param in self.encoder3.parameters():
param.requires_grad = False
for param in self.encoder4.parameters():
param.requires_grad = False
for param in self.encoder5.parameters():
param.requires_grad = False
# Bottleneck layer
self.bottleneck = nn.Conv2d(1280, 512, kernel_size=1)
# Decoder (Transposed Convolutions for upsampling)
self.decoder4 = self._upsample(512, 96) # 16 × 16
self.decoder3 = self._upsample(96, 32) # 32 × 32
self.decoder2 = self._upsample(32, 24) # 64 × 64
self.decoder1 = self._upsample(24, 16) # 128 × 128
# **Extra upsampling to reach 256 × 256**
self.final_up = self._upsample(16, 16) # 256 × 256
self.final_conv = nn.Conv2d(16, num_classes, kernel_size=1) # Final segmentation layer
def _upsample(self, in_channels, out_channels):
"""Helper function to create an upsampling block using transposed convolutions."""
return nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
def forward(self, x):
# Encoder
e1 = self.encoder1(x) # 128 × 128
e2 = self.encoder2(e1) # 64 × 64
e3 = self.encoder3(e2) # 32 × 32
e4 = self.encoder4(e3) # 16 × 16
e5 = self.encoder5(e4) # 8 × 8
# Bottleneck layer
b = self.bottleneck(e5)
# Decoder with optional skip connections
d4 = self.decoder4(b) + e4 if self.use_skip_connections else self.decoder4(b) # 16 × 16
d3 = self.decoder3(d4) + e3 if self.use_skip_connections else self.decoder3(d4) # 32 × 32
d2 = self.decoder2(d3) + e2 if self.use_skip_connections else self.decoder2(d3) # 64 × 64
d1 = self.decoder1(d2) + e1 if self.use_skip_connections else self.decoder1(d2) # 128 × 128
d0 = self.final_up(d1) # **Extra upsampling to reach 256 × 256**
# Final segmentation map
return self.final_conv(d0) # Shape: (batch_size, num_classes, 256, 256)
In [47]:
import networkx as nx
import matplotlib.pyplot as plt
def visualize_unet_with_correct_skips(use_skip_connections=True):
"""
Creates a U-Net architecture visualization with input/output sizes and channels at each layer.
Args:
use_skip_connections (bool): Whether to show skip connections.
Returns:
fig: Matplotlib figure to be displayed inline in Jupyter Notebook.
"""
G = nx.DiGraph()
# Define encoder layers with (Height x Width, Channels)
encoder_layers = {
"Input (256x256, 3)": (0, 5),
"Enc1 (128x128, 16)": (1, 4),
"Enc2 (64x64, 24)": (2, 3),
"Enc3 (32x32, 32)": (3, 2),
"Enc4 (16x16, 96)": (4, 1),
"Bottleneck (8x8, 512)": (5, 0)
}
# Define decoder layers with corresponding output sizes and channels
decoder_layers = {
"Dec4 (16x16, 96)": (6, 1),
"Dec3 (32x32, 32)": (7, 2),
"Dec2 (64x64, 24)": (8, 3),
"Dec1 (128x128, 16)": (9, 4),
"Final Up (256x256, 16)": (10, 5),
"Output (256x256, 5)": (11, 5.5)
}
# Add encoder and decoder nodes
for layer, pos in {**encoder_layers, **decoder_layers}.items():
G.add_node(layer, pos=pos)
# Connect encoder layers sequentially
encoder_keys = list(encoder_layers.keys())
for i in range(len(encoder_keys) - 1):
G.add_edge(encoder_keys[i], encoder_keys[i + 1])
# Connect bottleneck to decoder
G.add_edge("Bottleneck (8x8, 512)", "Dec4 (16x16, 96)")
# Connect decoder layers sequentially
decoder_keys = list(decoder_layers.keys())
for i in range(len(decoder_keys) - 2):
G.add_edge(decoder_keys[i], decoder_keys[i + 1])
# Connect final output
G.add_edge("Final Up (256x256, 16)", "Output (256x256, 5)")
# Add **correct** skip connections if enabled
if use_skip_connections:
skip_connections = {
"Enc1 (128x128, 16)": "Dec1 (128x128, 16)",
"Enc2 (64x64, 24)": "Dec2 (64x64, 24)",
"Enc3 (32x32, 32)": "Dec3 (32x32, 32)",
"Enc4 (16x16, 96)": "Dec4 (16x16, 96)"
}
for enc, dec in skip_connections.items():
G.add_edge(enc, dec, color="red", style="dashed")
# Extract positions for visualization
pos = nx.get_node_attributes(G, "pos")
# Create a Matplotlib figure
fig, ax = plt.subplots(figsize=(12, 6))
edges = G.edges()
# Color edges differently for skip connections
edge_colors = ["red" if G[u][v].get("color") == "red" else "black" for u, v in edges]
edge_styles = ["dashed" if G[u][v].get("style") == "dashed" else "solid" for u, v in edges]
# Draw nodes
nx.draw(G, pos, with_labels=True, node_color="lightblue", node_size=2500, font_size=8, font_weight="bold", ax=ax)
# Draw normal and skip edges separately to allow different styles
for i, (u, v) in enumerate(edges):
nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], edge_color=edge_colors[i], style=edge_styles[i], width=2, ax=ax)
ax.set_title(f"U-Net Architecture {'(With Skip Connections)' if use_skip_connections else '(Without Skip Connections)'}")
return fig
# Display the images inline in Jupyter Notebook
fig1 = visualize_unet_with_correct_skips(use_skip_connections=True) # With Skip Connections
fig2 = visualize_unet_with_correct_skips(use_skip_connections=False) # Without Skip Connections
# Show figures in Jupyter Notebook
plt.show()
In [48]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
class UNetLightning(pl.LightningModule):
def __init__(self, num_classes=5, lr=1e-3, freeze_encoder=True, use_skip_connections=True):
super(UNetLightning, self).__init__()
# Load U-Net with correct output shape (256x256)
self.model = UNetMobileNet(
num_classes=num_classes,
freeze_encoder=freeze_encoder,
use_skip_connections=use_skip_connections
)
# Loss function (Cross-entropy for multi-class segmentation)
self.criterion = nn.CrossEntropyLoss()
# IoU (Intersection over Union) metric
self.iou = torchmetrics.JaccardIndex(task="multiclass", num_classes=num_classes)
# Learning rate
self.lr = lr
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
""" Training step: Computes loss and IoU for segmentation. """
images, trimaps, _ = batch # Extract inputs and ground truth masks
logits = self.model(images) # Forward pass
loss = self.criterion(logits, trimaps.squeeze(1).long()) # Squeeze trimap to shape (batch_size, height, width)
# Compute IoU (Jaccard Index)
iou = self.iou(torch.argmax(logits, dim=1), trimaps.squeeze(1).long())
# Log loss and IoU for training
self.log("train_loss", loss, prog_bar=True)
self.log("train_iou", iou, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
""" Validation step: Evaluates loss and IoU on the validation set. """
images, trimaps, _ = batch
logits = self.model(images)
loss = self.criterion(logits, trimaps.squeeze(1).long()) # Squeeze trimap to shape (batch_size, height, width)
# Compute IoU
iou = self.iou(torch.argmax(logits, dim=1), trimaps.squeeze(1).long())
# Log metrics for validation
self.log("val_loss", loss, prog_bar=True)
self.log("val_iou", iou, prog_bar=True)
return loss
def predict_step(self, batch, batch_idx):
""" Prediction step: Runs inference on new images. """
images, _, _ = batch # We only need images for prediction
logits = self.model(images) # Forward pass
preds = torch.argmax(logits, dim=1) # Get predicted class labels
return preds # Return predicted segmentation masks
def configure_optimizers(self):
""" Optimizer and Learning Rate Scheduler """
optimizer = optim.Adam(self.parameters(), lr=self.lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)
return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
In [49]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger
# Set up logging and early stopping
csv_logger = CSVLogger(save_dir='logs/', name='UNetNoSkips', version="")
early_stop_callback = EarlyStopping(monitor='val_loss', patience=25, verbose=True, mode="min")
# Create the model instance
model = UNetLightning(use_skip_connections = False)
# Assume train_loader and val_loader are defined DataLoaders
trainer = pl.Trainer(
max_epochs=10,
logger=csv_logger,
callbacks=[early_stop_callback]
)
trainer.fit(model, train_loader, valid_loader)
# Save the final model state
trainer.save_checkpoint('logs/UNetNoSkips/final_model.ckpt')
/home/kmcalist/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead. warnings.warn( /home/kmcalist/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=MobileNet_V2_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V2_Weights.DEFAULT` to get the most up-to-date weights. warnings.warn(msg) GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs /home/kmcalist/.local/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py:269: Experiment logs directory logs/UNetNoSkips/ exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved! /home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory logs/UNetNoSkips/checkpoints exists and is not empty. LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params | Mode ------------------------------------------------------------- 0 | model | UNetMobileNet | 3.8 M | train 1 | criterion | CrossEntropyLoss | 0 | train 2 | iou | MulticlassJaccardIndex | 0 | train ------------------------------------------------------------- 1.6 M Trainable params 2.2 M Non-trainable params 3.8 M Total params 15.361 Total estimated model params size (MB)
Sanity Checking: | | 0/? [00:00<?, ?it/s]
/home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance. /home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance. /home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (15) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Training: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved. New best score: 1.304
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.245 >= min_delta = 0.0. New best score: 1.059
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.172 >= min_delta = 0.0. New best score: 0.887
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.138 >= min_delta = 0.0. New best score: 0.749
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.084 >= min_delta = 0.0. New best score: 0.665
Validation: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.132 >= min_delta = 0.0. New best score: 0.533
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.046 >= min_delta = 0.0. New best score: 0.487
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.479
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.033 >= min_delta = 0.0. New best score: 0.446 `Trainer.fit` stopped: `max_epochs=10` reached.
In [50]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger
# Set up logging and early stopping
csv_logger = CSVLogger(save_dir='logs/', name='UNetSkips', version="")
early_stop_callback = EarlyStopping(monitor='val_loss', patience=25, verbose=True, mode="min")
# Create the model instance
model = UNetLightning(use_skip_connections = True)
# Assume train_loader and val_loader are defined DataLoaders
trainer = pl.Trainer(
max_epochs=5,
logger=csv_logger,
callbacks=[early_stop_callback]
)
trainer.fit(model, train_loader, valid_loader)
# Save the final model state
trainer.save_checkpoint('logs/UNetSkips/final_model.ckpt')
GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs /home/kmcalist/.local/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py:269: Experiment logs directory logs/UNetSkips/ exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved! /home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory logs/UNetSkips/checkpoints exists and is not empty. LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params | Mode ------------------------------------------------------------- 0 | model | UNetMobileNet | 3.8 M | train 1 | criterion | CrossEntropyLoss | 0 | train 2 | iou | MulticlassJaccardIndex | 0 | train ------------------------------------------------------------- 1.6 M Trainable params 2.2 M Non-trainable params 3.8 M Total params 15.361 Total estimated model params size (MB)
Sanity Checking: | | 0/? [00:00<?, ?it/s]
Training: | | 0/? [00:00<?, ?it/s]
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved. New best score: 1.336
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.289 >= min_delta = 0.0. New best score: 1.047
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.035 >= min_delta = 0.0. New best score: 1.012
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.029 >= min_delta = 0.0. New best score: 0.984
Validation: | | 0/? [00:00<?, ?it/s]
Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 0.964 `Trainer.fit` stopped: `max_epochs=5` reached.
In [52]:
import torch
import pytorch_lightning as pl
# Load trained models
model_with_skips = UNetLightning.load_from_checkpoint(
'logs/UNetSkips/final_model.ckpt',
use_skip_connections=True # Ensure it matches training
)
model_without_skips = UNetLightning.load_from_checkpoint(
'logs/UNetNoSkips/final_model.ckpt',
use_skip_connections=False # Ensure it matches training
)
# Set models to evaluation mode
model_with_skips.eval()
model_without_skips.eval()
print("done")
done
/home/kmcalist/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead. warnings.warn( /home/kmcalist/.local/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=MobileNet_V2_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V2_Weights.DEFAULT` to get the most up-to-date weights. warnings.warn(msg)
In [53]:
import random
# Get the validation dataset from the DataLoader
valid_dataset = valid_loader.dataset
# Randomly select 10 unique indices
random_indices = random.sample(range(len(valid_dataset)), 10)
# Extract the corresponding images (without labels)
random_samples = [valid_dataset[i] for i in random_indices]
# Convert images to batch format for model prediction
images = torch.stack([sample[0] for sample in random_samples]) # Image tensors
trimaps = torch.stack([sample[1] for sample in random_samples]) # Ground truth trimaps
classes = torch.stack([sample[2] for sample in random_samples])
In [54]:
# Create DataLoader for prediction
pred_loader = torch.utils.data.DataLoader(list(zip(images, trimaps, classes)), batch_size=10)
In [55]:
# Define Lightning trainer (No training, just prediction mode)
trainer = pl.Trainer(accelerator="gpu" if torch.cuda.is_available() else "cpu")
# Run predictions with trainer.predict()
preds_with_skips = trainer.predict(model_with_skips, dataloaders=pred_loader)
preds_without_skips = trainer.predict(model_without_skips, dataloaders=pred_loader)
# Convert list of batch tensors to a single tensor
preds_with_skips = torch.cat(preds_with_skips, dim=0)
preds_without_skips = torch.cat(preds_without_skips, dim=0)
# Convert logits to class labels
preds_with_skips = torch.argmax(preds_with_skips, dim=1).cpu().numpy()
preds_without_skips = torch.argmax(preds_without_skips, dim=1).cpu().numpy()
trimaps = trimaps.squeeze(1).cpu().numpy()
GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs 2025-03-04 14:16:52.935266: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2025-03-04 14:16:52.935290: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2025-03-04 14:16:52.936145: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2025-03-04 14:16:52.940029: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 2025-03-04 14:16:53.520495: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] /home/kmcalist/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
Predicting: | | 0/? [00:00<?, ?it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting: | | 0/? [00:00<?, ?it/s]
In [59]:
preds_without_skips
Out[59]:
array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]])
In [ ]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors
# Define class-to-color mapping for trimap
trimap_colors = {
0: (0, 0, 0), # Background - Black
1: (1, 0, 0), # Outline - Red
2: (0, 1, 0), # Object - Green
}
# Create a ListedColormap for visualization
cmap = mcolors.ListedColormap([trimap_colors[i] for i in range(len(trimap_colors))])
# Show the first 10 images, ground truth trimaps, and predicted trimaps
fig, axes = plt.subplots(10, 4, figsize=(12, 30))
axes[0, 0].set_title("Input Image")
axes[0, 1].set_title("Ground Truth Trimap")
axes[0, 2].set_title("Prediction (With Skips)")
axes[0, 3].set_title("Prediction (No Skips)")
for i in range(10):
# Convert image tensor to NumPy format
img_np = images[i].permute(1, 2, 0).cpu().numpy()
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min()) # Normalize for display
# Convert trimap tensors to NumPy
trimap_np = trimaps[i]
pred_skips_np = preds_with_skips[i]
pred_no_skips_np = preds_without_skips[i]
axes[i, 0].imshow(img_np)
axes[i, 0].axis("off")
axes[i, 1].imshow(trimap_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
axes[i, 1].axis("off")
axes[i, 2].imshow(pred_skips_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
axes[i, 2].axis("off")
axes[i, 3].imshow(pred_no_skips_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1)
axes[i, 3].axis("off")
plt.tight_layout()
plt.show()
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[56], line 38 35 axes[i, 1].imshow(trimap_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1) 36 axes[i, 1].axis("off") ---> 38 axes[i, 2].imshow(pred_skips_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1) 39 axes[i, 2].axis("off") 41 axes[i, 3].imshow(pred_no_skips_np, cmap=cmap, vmin=0, vmax=len(trimap_colors)-1) File ~/.local/lib/python3.10/site-packages/matplotlib/__init__.py:1521, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs) 1518 @functools.wraps(func) 1519 def inner(ax, *args, data=None, **kwargs): 1520 if data is None: -> 1521 return func( 1522 ax, 1523 *map(cbook.sanitize_sequence, args), 1524 **{k: cbook.sanitize_sequence(v) for k, v in kwargs.items()}) 1526 bound = new_sig.bind(ax, *args, **kwargs) 1527 auto_label = (bound.arguments.get(label_namer) 1528 or bound.kwargs.get(label_namer)) File ~/.local/lib/python3.10/site-packages/matplotlib/axes/_axes.py:5945, in Axes.imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, colorizer, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs) 5942 if aspect is not None: 5943 self.set_aspect(aspect) -> 5945 im.set_data(X) 5946 im.set_alpha(alpha) 5947 if im.get_clip_path() is None: 5948 # image does not already have clipping set, clip to Axes patch File ~/.local/lib/python3.10/site-packages/matplotlib/image.py:675, in _ImageBase.set_data(self, A) 673 if isinstance(A, PIL.Image.Image): 674 A = pil_to_array(A) # Needed e.g. to apply png palette. --> 675 self._A = self._normalize_image_array(A) 676 self._imcache = None 677 self.stale = True File ~/.local/lib/python3.10/site-packages/matplotlib/image.py:643, in _ImageBase._normalize_image_array(A) 641 A = A.squeeze(-1) # If just (M, N, 1), assume scalar and apply colormap. 642 if not (A.ndim == 2 or A.ndim == 3 and A.shape[-1] in [3, 4]): --> 643 raise TypeError(f"Invalid shape {A.shape} for image data") 644 if A.ndim == 3: 645 # If the input data has values outside the valid range (after 646 # normalisation), we issue a warning and then clip X to the bounds 647 # - otherwise casting wraps extreme values, hiding outliers and 648 # making reliable interpretation impossible. 649 high = 255 if np.issubdtype(A.dtype, np.integer) else 1 TypeError: Invalid shape (256,) for image data